from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, TypeVar
import jax

import jax.numpy as jnp
from jaxtyping import PyTree

from nix.optax_ext import make_schedule
from nix.utils.jax_utils import pmean_if_pmap

# A damper's state may be an arbitrary pytree.
DampingState = PyTree
# Type of a damper's carried state. Declared without variance (invariant)
# because it appears both as an argument and as a return type of ``Damper``.
S = TypeVar('S', bound=DampingState)
# Type of the per-step input passed to ``Damper.update``. Contravariant
# because it only ever appears in argument position.
T = TypeVar('T', contravariant=True)


class Damper(Protocol[T, S]):
    """Structural interface implemented by every damping strategy.

    A damper owns a state of type ``S``. On each step it consumes an input of
    type ``T`` and yields a damping value together with the next state.
    """

    def init(self) -> S:
        """Build and return the initial damping state."""
        ...

    def update(self, args: T, state: S) -> tuple[jax.Array, S]:
        """Compute the damping for one step.

        Args:
            args: Per-step input; its meaning is defined by the implementation.
            state: State returned by ``init`` or by a previous ``update``.

        Returns:
            A ``(damping, new_state)`` pair.
        """
        ...


class DampingScheduleState(NamedTuple):
    """State carried by ``DampingSchedule``: how many updates have run."""

    # Step counter; ``DampingSchedule.update`` returns it incremented by one.
    step: jax.Array


@dataclass(frozen=True)
class DampingSchedule(Damper[None, DampingScheduleState]):
    """Damping that follows a fixed schedule indexed by the update count.

    Attributes:
        scheduler: Schedule name, forwarded as the first argument of
            ``make_schedule``.
        scheduler_args: Mapping unpacked as keyword arguments into
            ``make_schedule`` (each value is itself a dict, per the annotation).
    """

    # NOTE(review): frozen dataclasses generate ``__hash__`` from their fields,
    # so ``hash()`` on an instance raises TypeError because ``scheduler_args``
    # is a dict. Avoid using instances as jit static args / dict keys.
    scheduler: str
    scheduler_args: dict[str, dict[str, Any]]

    def init(self):
        """Return a state whose step counter starts at zero."""
        # NOTE(review): jnp.zeros(()) uses the default float dtype, so the
        # counter is a float (float32 unless x64 is enabled) — confirm intended.
        return DampingScheduleState(step=jnp.zeros(()))

    def update(self, args: None, state: DampingScheduleState):
        """Advance the counter and evaluate the schedule at the new step.

        ``args`` is unused. Because the counter is incremented before the
        schedule is evaluated, the first call evaluates the schedule at step 1.

        Returns:
            A ``(damping, new_state)`` pair.
        """
        advanced = DampingScheduleState(step=state.step + 1)
        schedule_fn = make_schedule(self.scheduler, **self.scheduler_args)
        return schedule_fn(advanced.step), advanced


class StdDampingState(NamedTuple):
    """State carried by ``StdBasedDamping``: the current damping value."""

    # Damping carried between updates (``StdBasedDamping.init`` stores a
    # 0-d array here).
    damping: jax.Array


@dataclass(frozen=True)
class StdBasedDamping(Damper[jax.Array, StdDampingState]):
    """Adaptive damping that tracks a target derived from the spread of ``args``.

    On every update the target is ``base * sqrt(var) ** target_power``, where
    ``var`` is the variance of ``args`` along ``axis``, averaged over any
    remaining axes. The target is then passed through ``pmean_if_pmap``
    (presumably a cross-device mean when running under ``pmap`` — confirm).
    The damping is moved multiplicatively toward the target by ``decay`` and
    finally clipped to ``[min_damp, max_damp]``.

    Attributes:
        init_value: Damping value stored in the initial state.
        base: Multiplier applied to the std-based target.
        axis: Axis or axes along which the variance of ``args`` is taken
            (``None`` reduces over all axes).
        target_power: Exponent applied to ``sqrt(var)`` when forming the target.
        decay: Multiplicative step factor. Damping is divided by it to grow and
            multiplied by it to shrink, so it presumably lies in ``(0, 1)``.
        max_damp: Upper clip bound for the damping.
        min_damp: Lower clip bound for the damping (defaults to ``1e-8``).
    """

    init_value: float
    base: float
    axis: int | tuple[int, ...] | None
    target_power: float = 0.5
    decay: float = 0.999
    max_damp: float = 1e-1
    # Appended last with a default so existing positional/keyword construction
    # keeps working unchanged.
    min_damp: float = 1e-8

    def init(self):
        """Return the initial state holding ``init_value`` as a 0-d array."""
        return StdDampingState(jnp.full((), self.init_value))

    def update(self, args: jax.Array, state: StdDampingState):
        """Move the damping one multiplicative step toward the std-based target.

        Args:
            args: Array whose variance along ``self.axis`` sets the target.
            state: Previous damping state.

        Returns:
            A ``(damping, new_state)`` pair; ``new_state`` stores the same
            damping value that is returned.
        """
        var = args.var(axis=self.axis).mean()
        std = jnp.sqrt(var)
        target = pmean_if_pmap(self.base * jnp.power(std, self.target_power))

        damping = state.damping
        # Grow toward the target when currently below it.
        damping = jnp.where(damping < target, damping / self.decay, damping)
        # Shrink when above it. This compares the *already grown* value, so an
        # upward step that overshoots the target is undone (up to rounding)
        # within the same update. For 0 < decay < 1 and a steady target the
        # damping therefore settles just below the target rather than
        # oscillating around it.
        damping = jnp.where(damping > target, self.decay * damping, damping)
        damping = jnp.clip(damping, self.min_damp, self.max_damp)
        return damping, StdDampingState(damping)
