import nix.linalg as linalg
from .natural_gradient import NaturalGradientPreconditioner, NaturalGradientState
from .optax_ext import decay_to_init, prodigy, scale_by_trust_ratio_embeddings
from .utils import tree_utils
from .utils.jax_utils import (
    broadcast,
    instance,
    jit,
    p_split,
    pgather_if_pmap,
    pmap,
    pmax_if_pmap,
    pmean_if_pmap,
    pmin_if_pmap,
    psum_if_pmap,
    replicate,
)
from .utils.moving_average import EMA, Average, MovingAverage

# Names exported by a wildcard import of this module (`from ... import *`).
# Every entry is a name imported at the top of this file; the groups below
# follow the module each name is imported from. Kept as one plain list
# literal so the export list stays statically readable.
__all__ = [
    # from .utils.moving_average
    'EMA',
    'Average',
    'MovingAverage',
    # from .optax_ext
    'decay_to_init',
    'prodigy',
    'scale_by_trust_ratio_embeddings',
    # from .utils (re-exported as a submodule)
    'tree_utils',
    # from .utils.jax_utils
    'broadcast',
    'instance',
    'jit',
    'p_split',
    'pgather_if_pmap',
    'pmap',
    'pmax_if_pmap',
    'pmean_if_pmap',
    'pmin_if_pmap',
    'psum_if_pmap',
    'replicate',
    # from .natural_gradient
    'NaturalGradientPreconditioner',
    'NaturalGradientState',
    # nix.linalg, bound here under the name `linalg`
    'linalg',
]
