import jax
import jax.flatten_util as jfu
import jax.numpy as jnp
import jax.tree_util as jtu
from typing import Any, Callable, NamedTuple, Protocol, TypeVarTuple, TypeVar, Generic
from dataclasses import dataclass
from nix.natural_gradient.damping import Damper, DampingState

from .cg import cg
from ..utils.jax_utils import pmean_if_pmap, pall_to_all, psum_if_pmap, pgather, pidx
from ..utils.tree_utils import tree_add, tree_mul, tree_sub, tree_squared_norm


# Type of the model parameters; the preconditioners treat it as a pytree and
# return gradients of the same type.
T = TypeVar('T')
# Types of the positional inputs forwarded to ``log_p`` after the parameters.
Ts = TypeVarTuple('Ts')
# Type of the input handed to ``Damper.update`` on every preconditioning step.
D = TypeVar('D')
# Type of the damper's state; restricted to subclasses of ``DampingState``.
S = TypeVar('S', bound=DampingState)


class NaturalGradientState(NamedTuple, Generic[T, S]):
    """State carried between calls of the CG-based preconditioners."""

    # Solution of the previous CG solve; passed to ``cg`` as ``x0`` on the
    # next call.
    last_grad: T
    # State object produced by ``Damper.init`` / ``Damper.update``.
    damper_state: S


class VectorizeFunction(Protocol[T, *Ts]):
    """Function transformation that keeps the signature of the wrapped function.

    The preconditioners apply it to ``log_p`` (and to per-sample helper
    functions) before calling the result on the full inputs -- presumably a
    ``jax.vmap``-style batching wrapper; confirm against the caller.
    """

    # Takes ``fn`` and returns a callable with the same argument and return types.
    def __call__(self, fn: Callable[[*Ts], T]) -> Callable[[*Ts], T]: ...


class NaturalGradientPreconditioner(Protocol[T, D, S, *Ts]):
    """Structural interface implemented by the preconditioners in this module."""

    # Build the initial preconditioner state for the given parameters. The
    # concrete state type differs between implementations.
    def init(self, params: T): ...

    # Turn ``dloss_dlog_p`` (derivative of the loss w.r.t. the outputs of the
    # vectorized ``log_p``) into a parameter update.
    #
    # Args:
    #   params: current parameters.
    #   log_p_input: inputs passed to ``log_p`` after the parameters.
    #   damping_input: value forwarded to the damper.
    #   dloss_dlog_p: cotangent for the ``log_p`` outputs.
    #   state: state returned by ``init`` or by the previous call.
    #
    # Returns:
    #   ``(update, new_state, aux)`` where ``aux`` maps names to arrays (the
    #   implementations in this module report the key ``'damping'``).
    def precondition(
        self,
        params: T,
        log_p_input: tuple[*Ts],
        damping_input: D,
        dloss_dlog_p: jax.Array,
        state: Any,
    ) -> tuple[T, Any, dict[str, jax.Array]]: ...


@dataclass(frozen=True)
class CGNaturalGradientPreconditioner(NaturalGradientPreconditioner[T, D, S, *Ts]):
    log_p: Callable[[T, *Ts], jax.Array]
    vectorize_fn: VectorizeFunction[jax.Array, T, *Ts]
    damper: Damper[D, S]
    enable: bool = True
    maxiter: int = 100
    precision: str = 'float32'

    def init(self, params: T) -> NaturalGradientState[T, S]:
        return NaturalGradientState(
            last_grad=jtu.tree_map(jnp.zeros_like, params),
            damper_state=self.damper.init(),
        )

    def precondition(
        self,
        params: T,
        log_p_input: tuple[*Ts],
        damping_input: D,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState[T, S],
    ) -> tuple[T, NaturalGradientState[T, S], dict[str, jax.Array]]:
        N = dloss_dlog_p.size

        def log_p_closure(p: T):
            return self.vectorize_fn(self.log_p)(p, *log_p_input)

        _, vjp_fn = jax.vjp(log_p_closure, params)
        gradient = pmean_if_pmap(vjp_fn(dloss_dlog_p / N)[0])

        if not self.enable:
            return gradient, natgrad_state, {'damping': jnp.array(0.0)}

        _, jvp_fn = jax.linearize(log_p_closure, params)

        damping, damping_state = self.damper.update(
            damping_input, state=natgrad_state.damper_state
        )

        def Fisher_matmul(v):
            with jax.default_matmul_precision(self.precision):
                w = jvp_fn(v) / N
                uncentered = vjp_fn(w)[0]
                result = tree_add(uncentered, tree_mul(v, damping))
                result = pmean_if_pmap(result)
                return result

        # Compute natural gradient
        natgrad = cg(
            A=Fisher_matmul,
            b=gradient,
            x0=natgrad_state.last_grad,
            fixed_iter=jax.device_count()
            > 1,  # if we have multiple GPUs we must do a fixed number of iterations
            maxiter=self.maxiter,
        )[0]

        return (
            natgrad,
            NaturalGradientState(natgrad, damping_state),
            {'damping': damping},
        )


@dataclass(frozen=True)
class CGNaturalGradientPreconditioner2(NaturalGradientPreconditioner[T, D, S, *Ts]):
    log_p: Callable[[T, *Ts], jax.Array]
    vectorize_fn: VectorizeFunction[jax.Array, T, *Ts]
    damper: Damper[D, S]
    enable: bool = True
    maxiter: int = 100
    precision: str = 'float32'

    def init(self, params: T) -> NaturalGradientState[T, S]:
        return NaturalGradientState(last_grad=None, damper_state=self.damper.init())  # type: ignore

    def precondition(
        self,
        params: T,
        log_p_input: tuple[*Ts],
        damping_input: D,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState[T, S],
    ) -> tuple[T, NaturalGradientState[T, S], dict[str, jax.Array]]:
        N = dloss_dlog_p.size

        def log_p_closure(p: T):
            return self.vectorize_fn(self.log_p)(p, *log_p_input)

        _, vjp_fn = jax.vjp(log_p_closure, params)
        gradient = pmean_if_pmap(vjp_fn(dloss_dlog_p / N)[0])

        if not self.enable:
            return gradient, natgrad_state, {'damping': jnp.array(0.0)}

        _, jvp_fn = jax.linearize(log_p_closure, params)

        damping, damping_state = self.damper.update(
            damping_input, state=natgrad_state.damper_state
        )

        def Fisher_matmul(v):
            with jax.default_matmul_precision(self.precision):
                x_v = tree_mul(pmean_if_pmap(vjp_fn(v)[0]), 1 / N)
                result = jvp_fn(x_v)
                result = tree_add(result, tree_mul(v, damping))
                return result

        # Compute natural gradient
        preconditioned_dloss = cg(
            A=Fisher_matmul,
            b=dloss_dlog_p,
            x0=natgrad_state.last_grad,
            fixed_iter=jax.device_count()
            > 1,  # if we have multiple GPUs we must do a fixed number of iterations
            maxiter=self.maxiter,
            distributed_vector=True,
        )[0]

        natgrad = pmean_if_pmap(vjp_fn(preconditioned_dloss / N)[0])
        return (
            natgrad,
            NaturalGradientState(preconditioned_dloss, damping_state),
            {'damping': damping},
        )


@dataclass(frozen=True)
class BlockDiagNaturalGradientPreconditioner(
    NaturalGradientPreconditioner[T, D, S, *Ts]
):
    log_p: Callable[[T, *Ts], jax.Array]
    vectorize_fn: Callable[
        [Callable[[T, *Ts], jax.Array]], Callable[[T, *Ts], jax.Array]
    ]
    damper: Damper[D, S]
    enable: bool = True
    precision: str = 'float32'

    def init(self, params: T):
        return self.damper.init()

    def precondition(
        self,
        params: T,
        log_p_input: tuple[*Ts],
        damping_input: D,
        dloss_dlog_p: jax.Array,
        damping_state: S,
    ):
        N = dloss_dlog_p.size

        def log_p_closure(p: T):
            return self.vectorize_fn(self.log_p)(p, *log_p_input)

        _, vjp_fn = jax.vjp(log_p_closure, params)

        if not self.enable:
            return vjp_fn(dloss_dlog_p)[0], damping_state, {'damping': jnp.array(0.0)}

        damping, damping_state = self.damper.update(damping_input, state=damping_state)

        _, jvp = jax.linearize(
            lambda x: self.vectorize_fn(self.log_p)(x, *log_p_input), params
        )
        grad_f = jax.grad(self.log_p)

        @self.vectorize_fn
        def row(params, *inp_i: *Ts):
            return jvp(grad_f(params, *inp_i))

        J_TJ = row(params, *log_p_input).reshape(N, N)

        I = jnp.eye(N)
        cotangent = jnp.linalg.solve(J_TJ + damping * I, dloss_dlog_p.reshape(-1))

        natgrad = pmean_if_pmap(vjp_fn(cotangent.reshape(dloss_dlog_p.shape))[0])

        return natgrad, damping_state, {'damping': damping}


@dataclass(frozen=True)
class ExactNaturalGradientPreconditioner(NaturalGradientPreconditioner[T, D, S, *Ts]):
    log_p: Callable[[T, *Ts], jax.Array]
    vectorize_fn: Callable[
        [Callable[[T, *Ts], jax.Array]], Callable[[T, *Ts], jax.Array]
    ]
    damper: Damper[D, S]
    enable: bool = True
    precision: str = 'float32'

    def init(self, params: T):
        return self.damper.init()

    def precondition(
        self,
        params: T,
        log_p_input: tuple[*Ts],
        damping_input: D,
        dloss_dlog_p: jax.Array,
        damping_state: S,
    ):
        flat_params, unravel = jfu.ravel_pytree(params)
        N = dloss_dlog_p.size
        P = flat_params.size

        def flat_log_p(p: jax.Array, *args: *Ts):
            return self.log_p(unravel(p), *args)

        def log_p_closure(p: T):
            return self.vectorize_fn(self.log_p)(p, *log_p_input)

        _, vjp_fn = jax.vjp(log_p_closure, params)

        if not self.enable:
            return vjp_fn(dloss_dlog_p)[0], damping_state, {'damping': jnp.array(0.0)}

        damping, damping_state = self.damper.update(damping_input, state=damping_state)

        n_dev = jax.device_count()
        jacobian = self.vectorize_fn(jax.grad(flat_log_p))(flat_params, *log_p_input)
        jacobian = jacobian.reshape(-1, flat_params.size) / jnp.sqrt(N * n_dev)
        jacobian = jacobian - pmean_if_pmap(jnp.mean(jacobian, axis=0))
        if P % n_dev != 0 and n_dev > 1:
            jacobian = jnp.concatenate(
                [jacobian, jnp.zeros((N, n_dev - (P % n_dev)))], axis=-1
            )
        # part_jac now has the full n_dev*samples_per_dev on the first dim and a partial set of parameters
        # on the second dim.
        jacobian = pall_to_all(jacobian, split_axis=1, concat_axis=0, tiled=True)
        with jax.default_matmul_precision(self.precision):
            J_TJ = psum_if_pmap(jacobian @ jacobian.T)

        # Accumulate the derivative of the loss w.r.t. to the function of all samples.
        all_dloss_dlog_p = pgather(dloss_dlog_p, axis=0).reshape(-1)

        I = jnp.eye(N * n_dev)
        cotangent = jax.scipy.linalg.solve(
            J_TJ + damping * I, all_dloss_dlog_p / jnp.sqrt(N * n_dev), assume_a='pos'
        )
        cotangent = cotangent.reshape(n_dev, -1)[pidx()]

        natgrad = pmean_if_pmap(cotangent @ jacobian)
        natgrad = unravel(natgrad)

        return natgrad, damping_state, {'damping': damping}


@dataclass(frozen=True)
class DiagonalNaturalGradientPreconditioner(
    NaturalGradientPreconditioner[T, D, S, *Ts]
):
    log_p: Callable[[T, *Ts], jax.Array]
    vectorize_fn: Callable[
        [Callable[[T, *Ts], jax.Array]], Callable[[T, *Ts], jax.Array]
    ]
    damper: Damper[D, S]
    enable: bool = True
    precision: str = 'float32'

    def init(self, params: T):
        return self.damper.init()

    def precondition(
        self,
        params: T,
        log_p_input: tuple[*Ts],
        damping_input: D,
        dloss_dlog_p: jax.Array,
        damping_state: S,
    ):
        # Here we use the preconditioner
        # 1/lambda I_P - 1/lambda J(I_N + 1/lambda J^T J)^-1 J^T 1/lambda
        # we simplify this to
        # 1/lambda I_P - 1/lambda J(I_N*lambda + J^T J)^-1 J^T
        # finally we only compute the diagonal of the inverted matrix.
        N = dloss_dlog_p.size

        def log_p_closure(p: T):
            return self.vectorize_fn(self.log_p)(p, *log_p_input)

        _, vjp_fn = jax.vjp(log_p_closure, params)
        gradient = pmean_if_pmap(vjp_fn(dloss_dlog_p / N)[0])

        if not self.enable:
            return gradient, damping_state, {'damping': jnp.array(0.0)}

        damping, damping_state = self.damper.update(damping_input, state=damping_state)

        _, jvp_fn = jax.linearize(log_p_closure, params)
        J_T_grad = jvp_fn(gradient) / N

        @self.vectorize_fn
        def J_TJ_diag(params: T, *args: *Ts):
            grad = jax.grad(self.log_p)(params, *args)
            return damping * N + tree_squared_norm(grad)

        diag_inv_J_T_grad = J_T_grad / J_TJ_diag(params, *log_p_input)
        J_diag_inv_J_T_grad = vjp_fn(diag_inv_J_T_grad)[0]

        natgrad = tree_mul(tree_sub(gradient, J_diag_inv_J_T_grad), 1 / damping)
        natgrad = pmean_if_pmap(natgrad)

        return (
            natgrad,
            damping_state,
            {
                'damping': damping,
            },
        )
