import functools
from typing import Sequence

import flax
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
from chex import PRNGKey
from flax.core import unfreeze
from folx import batched_vmap
from jaxtyping import PyTree

from globe.nn.globe import Globe
from globe.nn.parameters import ParamTree
from globe.utils.config import SystemConfigs
from globe.utils.jax_utils import broadcast, instance, p_split, pmap, replicate
from globe.utils.optim import (
    NaturalGradientState,
    make_memefficient_spring_preconditioner,
    make_block_spring_preconditioner,
    make_natural_gradient_preconditioner,
    make_schedule,
    scale_by_trust_ratio_embeddings,
)
from globe.vmc.hamiltonian import make_local_energy_function
from globe.vmc.mcmc import make_mcmc
from globe.vmc.optim import (
    make_loss_and_natural_gradient_fn,
    make_std_based_damping_fn,
    make_training_step,
)
from globe.vmc.pretrain import make_mol_param_loss, make_pretrain_step


class Trainer:
    """
    Trainer class for training the electronic wave function.
    Provides methods for training the wave function with VMC and for pretraining the molecular parameters.
    Also, provides methods for saving and loading the parameters.

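    Args:
    - key: random key
    - globe: dictionary of keyword arguments used to construct the `Globe` network
    - mcmc_steps: number of MCMC steps
    - preconditioner: name of the preconditioner ('cg', 'spring' or 'block-spring')
    - preconditioner_args: dictionary of preconditioner keyword arguments, keyed by 'cg' or 'spring'
    - loss: dictionary of parameters for the loss function
    - lr: dictionary of parameters for the learning rate schedule
    - damping: dictionary of parameters for the damping schedule
    - operator: operator used to build the local energy function
    - batched_vmap_size: maximum batch size of the batched energy evaluation when operator is 'forward'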
    """

    def __init__(
        self,
        key: PRNGKey,
        globe: dict,
        mcmc_steps: int,
        preconditioner: str,
        preconditioner_args: dict,
        loss: dict,
        lr: dict,
        damping: dict,
        operator: str,
        batched_vmap_size: int = 64,
    ):
        """
        Build the network, its parallel evaluation functions, the VMC training
        step and the pretraining step, and initialize the parameters.

        Args:
        - key: random key; split into one host key and one key per device.
        - globe: keyword arguments forwarded to `Globe(**globe)`. The keys
          'gnn_params' and 'orbital_type' are also read directly here.
        - mcmc_steps: number of MCMC steps passed to `make_mcmc` for training.
        - preconditioner: one of 'cg', 'spring' or 'block-spring'.
        - preconditioner_args: dictionary holding the keyword arguments of the
          preconditioner under the key 'cg' (for 'cg') or 'spring' (for both
          'spring' and 'block-spring').
        - loss: keyword arguments forwarded to `make_loss_and_natural_gradient_fn`.
        - lr: argument of `make_schedule`; only used by the 'cg' preconditioner.
        - damping: keyword arguments forwarded to `make_std_based_damping_fn`;
          its 'init' entry is also read by `init_natgrad_state`.
        - operator: forwarded to `make_local_energy_function`; the value
          'forward' (case-insensitive) additionally selects `batched_vmap_size`.
        - batched_vmap_size: maximum batch size of the batched energy evaluation
          when `operator` is 'forward'; otherwise a fixed size of 256 is used.
        Raises:
        - ValueError: if `preconditioner` is not one of the supported names.
        """
        self.globe_params = globe
        self.preconditioner = preconditioner
        self.preconditioner_args = preconditioner_args
        self.loss_params = loss
        self.lr_params = lr
        self.damping_params = damping
        self.mcmc_steps = mcmc_steps
        self.operator = operator
        # Prepare random keys
        # The first split stays on the trainer as `self.key`; the remaining
        # `jax.device_count()` keys are stacked and handed to `broadcast`
        # (presumably one key per device -- confirm against `broadcast`).
        self.key, *subkeys = jax.random.split(key, jax.device_count() + 1)
        self.shared_key = broadcast(jnp.stack(subkeys))

        # Prepare all necessary functions
        self.network = Globe(**globe)
        self.wf = functools.partial(self.network.apply, method=self.network.wf)
        # Parallel wave function: the inner vmap maps over positional argument 1
        # only; the outer pmap maps over arguments 0, 1 and 4 and treats
        # argument 3 as static.
        self.p_wf = pmap(
            jax.vmap(self.wf, in_axes=(None, 0, None, None, None)),
            in_axes=(0, 0, None, None, 0),
            static_broadcasted_argnums=3,
        )
        # Parallel default forward pass of the network (no `method` override).
        self.fwd = pmap(
            jax.vmap(self.network.apply, in_axes=(None, 0, None, None)),
            in_axes=(0, 0, None, None),
            static_broadcasted_argnums=3,
        )
        # Parallel evaluation of the network's `signed` method.
        self.signed = pmap(
            jax.vmap(
                functools.partial(self.network.apply, method=self.network.signed),
                in_axes=(None, 0, None, None),
            ),
            in_axes=(0, 0, None, None),
            static_broadcasted_argnums=3,
        )
        # `get_mol_params` is first bound to the network method and then replaced
        # by its jitted version with positional argument 2 static.
        self.get_mol_params = functools.partial(
            self.network.apply, method=self.network.get_mol_params
        )
        self.get_mol_params = jax.jit(self.get_mol_params, static_argnums=2)
        self.p_get_mol_params = pmap(
            self.get_mol_params, in_axes=(0, None, None), static_broadcasted_argnums=2
        )
        # Forward pass with `capture_intermediates=True`, used by `intermediates`.
        self.get_intermediates = jax.vmap(
            functools.partial(self.network.apply, capture_intermediates=True),
            in_axes=(None, 0, None, None),
        )
        self.p_get_intermediates = pmap(
            self.get_intermediates,
            in_axes=(0, 0, None, None),
            static_broadcasted_argnums=3,
        )

        # Initialize parameters
        # The network is initialized on a dummy system: a (2, 3) array of ones as
        # electrons, a (1, 3) array of zeros as atoms, spins (1, 1) and a single
        # nucleus. Its charge is the last entry of globe['gnn_params']['charges']
        # (default [2]), forced to 2 when the orbital type is 'globe'.
        test_charge = globe['gnn_params'].get('charges', [2])[-1]
        if globe['orbital_type'] == 'globe':
            test_charge = 2
        self.init_params(
            jnp.ones((2, 3)),
            jnp.zeros((1, 3)),
            SystemConfigs(spins=((1, 1),), charges=((test_charge,),)),
        )

        # Prepare VMC training functions
        self.wf_mcmc = make_mcmc(self.wf, self.mcmc_steps)
        # Positional argument 1 is passed as `donate_argnums`.
        self.p_wf_mcmc = pmap(
            self.wf_mcmc,
            in_axes=(0, 0, None, None, 0, 0, None),
            static_broadcasted_argnums=3,
            donate_argnums=1,
        )

        # Training MCMC: computes the molecular parameters from `params` first,
        # then runs `self.wf_mcmc` with them.
        @functools.partial(jax.jit, static_argnums=3)
        def mcmc(params, electrons, atoms, config, key, widths):
            mol_params = self.get_mol_params(params, atoms, config)
            return self.wf_mcmc(
                params, electrons, atoms, config, mol_params, key, widths
            )

        self.mcmc = mcmc

        # Pretraining uses its own MCMC with a fixed 5 steps, independent of
        # `mcmc_steps`.
        pre_mcmc = make_mcmc(self.wf, 5)

        # NOTE: despite its name, `p_mcmc` is a plain closure; it is neither
        # jitted nor pmapped here. It is only passed to `make_pretrain_step`.
        def p_mcmc(params, electrons, atoms, config, key, widths):
            mol_params = self.get_mol_params(params, atoms, config)
            return pre_mcmc(params, electrons, atoms, config, mol_params, key, widths)

        self.wf_energy = make_local_energy_function(
            self.wf, self.network.group_parameters, operator
        )

        # Local energy with the molecular parameters derived from `params`.
        @functools.partial(jax.jit, static_argnums=3)
        def energy_fn(params, electrons, atoms, config):
            mol_params = self.get_mol_params(params, atoms, config)
            return self.wf_energy(params, electrons, atoms, config, mol_params)

        self.energy = energy_fn
        # The energy is evaluated in chunks over positional argument 1. Only the
        # 'forward' operator honours `batched_vmap_size`; every other operator
        # uses a hard-coded maximum batch size of 256.
        if operator.lower() == 'forward':
            self.v_energy = batched_vmap(
                self.energy,
                max_batch_size=batched_vmap_size,
                in_axes=(None, 0, None, None),
            )
        else:
            self.v_energy = batched_vmap(
                self.energy, max_batch_size=256, in_axes=(None, 0, None, None)
            )
            # self.v_energy = jax.vmap(self.energy, in_axes=(None, 0, None, None))
        self.p_energy = pmap(
            self.v_energy, in_axes=(0, 0, None, None), static_broadcasted_argnums=3
        )

        # Select the preconditioner and the matching optimizer chain.
        # - 'cg': clip to global norm 1.0, scale by the `lr` schedule, negate.
        # - 'spring' / 'block-spring': scale by the fixed schedule
        #   -0.02 / (1 + 1e-4 * t), then clip to global norm sqrt(1e-3).
        #   The `lr` argument is not used in these two branches.
        # NOTE: 'block-spring' reads its arguments from
        # preconditioner_args['spring'], the same key as 'spring'.
        if self.preconditioner == 'cg':
            self.natgrad_precond = make_natural_gradient_preconditioner(
                self.network, **self.preconditioner_args['cg']
            )
            self.optimizer = optax.chain(
                optax.clip_by_global_norm(1.0),
                optax.scale_by_schedule(make_schedule(lr)),
                optax.scale(-1.0),
            )
        elif self.preconditioner == 'spring':
            self.natgrad_precond = make_memefficient_spring_preconditioner(
                self.network, **self.preconditioner_args['spring']
            )
            self.optimizer = optax.chain(
                optax.scale_by_schedule(lambda t: -0.02 / (1 + 1e-4 * t)),
                optax.clip_by_global_norm(1e-3**0.5),
            )
        elif self.preconditioner == 'block-spring':
            self.natgrad_precond = make_block_spring_preconditioner(
                self.network, **self.preconditioner_args['spring']
            )
            self.optimizer = optax.chain(
                optax.scale_by_schedule(lambda t: -0.02 / (1 + 1e-4 * t)),
                optax.clip_by_global_norm(1e-3**0.5),
            )
        else:
            raise ValueError(f'Unknown preconditioner: {self.preconditioner}')

        self.loss_and_grad = make_loss_and_natural_gradient_fn(
            self.network.apply,
            natgrad_precond=self.natgrad_precond,
            **self.loss_params,
        )

        self.damping_fn = make_std_based_damping_fn(**self.damping_params)
        self.train_step = make_training_step(
            self.mcmc,
            self.loss_and_grad,
            self.v_energy,
            self.damping_fn,
            self.optimizer.update,
        )

        # Prepare pretraining functions
        # Traversals selecting parameters whose path contains 'kernel',
        # 'embedding' or 'prefixed', respectively.
        kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
        embeddings = flax.traverse_util.ModelParamTraversal(
            lambda p, _: 'embedding' in p
        )
        prefixed_params = flax.traverse_util.ModelParamTraversal(
            lambda p, _: 'prefixed' in p
        )

        # Boolean masks with the structure of the parameter tree: start from an
        # all-False tree and set the selected leaves to True.
        all_false = jtu.tree_map(lambda _: False, self.params['params'])
        kernel_mask = kernels.update(lambda _: True, all_false)
        embedding_mask = embeddings.update(lambda _: True, all_false)
        prefixed_mask = prefixed_params.update(lambda _: True, all_false)

        # Pretraining optimizer: global-norm clipping, zeroed updates for the
        # 'prefixed' parameters, Adam scaling, trust-ratio scaling applied
        # separately to kernels and embeddings, and the fixed schedule
        # -1e-3 / (1 + 1e-4 * t).
        self.pre_optimizer = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.masked(optax.set_to_zero(), prefixed_mask),
            optax.scale_by_adam(),
            optax.masked(optax.scale_by_trust_ratio(), kernel_mask),
            optax.masked(scale_by_trust_ratio_embeddings(), embedding_mask),
            optax.scale_by_schedule(lambda t: -1e-3 / (1 + 1e-4 * t)),
        )

        # The auxiliary molecular-parameter loss gets weight 1e-6 unless the
        # network's `meta_model` is 'none', in which case the weight is 0.
        self.pre_step = make_pretrain_step(
            p_mcmc,
            self.get_mol_params,
            self.network.apply,
            functools.partial(
                self.network.apply, function='orbitals', method=self.network.wf
            ),
            self.pre_optimizer.update,
            orbital_matching_fn=self.network.match_orbitals,
            mol_param_aux_loss=make_mol_param_loss(
                self.network.param_spec(),
                1e-6 if self.network.meta_model != 'none' else 0,
            ),
        )

        # Init data
        # Both optimizer states are created lazily by the `opt_state` and
        # `pre_opt_state` properties on first access.
        self._opt_state = None
        self._pre_opt_state = None
        self.iteration = 0
        self.init_natgrad_state()

        # Placeholder until the first call to `step` overwrites it.
        self.last_grad_norm = 1e8

    def init_params(
        self, electrons: jax.Array, atoms: jax.Array, config: SystemConfigs
    ):
        """
        Initialize the parameters of the network.

        Args:
        - electrons: The electrons in the system.
        - atoms: The atoms in the system.
        - config: The system configuration.
        """
        self.key, subkey = jax.random.split(self.key)
        self.params = unfreeze(self.network.init(subkey, electrons, atoms, config))
        self.params = replicate(self.params)

    def init_natgrad_state(self):
        """
        Initialize the natural gradient state.
        """
        self.natgrad_state = NaturalGradientState(
            damping=replicate(
                jnp.array(self.damping_params['init'], dtype=jnp.float64)
            ),
            last_grad=broadcast(
                jtu.tree_map(
                    lambda x: jnp.zeros_like(x, dtype=jnp.float64),
                    self.params['params'],
                )
            ),
        )

    def mol_params(self, atoms: jax.Array, config: SystemConfigs):
        """
        Get the adaptive parameters for the specified systems.

        Args:
        - atoms: The atoms in the system.
        - config: The system configurations.
        """
        return self.get_mol_params(instance(self.params), atoms, config)

    @property
    def opt_state(self):
        if self._opt_state is None:
            self._opt_state = pmap(self.optimizer.init)(self.params['params'])
            self.iteration = 0
            self.init_natgrad_state()
        return self._opt_state

    @opt_state.setter
    def opt_state(self, val):
        self._opt_state = val

    @property
    def pre_opt_state(self):
        if self._pre_opt_state is None:
            self._pre_opt_state = pmap(self.pre_optimizer.init)(self.params['params'])
            self.init_natgrad_state()
        return self._pre_opt_state

    @pre_opt_state.setter
    def pre_opt_state(self, val):
        self._pre_opt_state = val

    def next_key(self):
        self.shared_key, result = p_split(self.shared_key)
        return result

    def intermediates(
        self, electrons: jax.Array, atoms: jax.Array, config: SystemConfigs
    ) -> ParamTree:
        """
        Get all intermediate tensors of the network's forward pass.

        Args:
        - electrons: The electrons in the system.
        - atoms: The atoms in the system.
        - config: The system configuration.
        Returns:
        - intermediates: The intermediate tensors.
        """
        return self.p_get_intermediates(self.params, electrons[:1, :1], atoms, config)[
            1
        ]

    def pretrain_step(
        self,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        target_fns: Sequence,
        properties: dict[str, jax.Array],
        cache: Sequence[PyTree],
    ):
        """
        Perform a single pretraining step.

        Args:
        - electrons: The electrons in the systems.
        - atoms: The atoms in the systems.
        - config: The system configurations.
        - scfs: The SCF objects for the systems.
        - properties: The properties of the systems.
        Returns:
        - loss: The loss of the pretraining step.
        - electrons: The electrons in the systems.
        - pmove: The proportion of electrons that moved.
        """
        (
            self.params,
            electrons,
            self.pre_opt_state,
            self.natgrad_state,
            losses,
            pmove,
            cache,
        ) = self.pre_step(
            self.params,
            electrons,
            atoms,
            config,
            target_fns,
            self.pre_opt_state,
            self.next_key(),
            properties,
            self.natgrad_state,
            cache,
        )
        return losses, electrons, pmove, cache

    def step(
        self,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        properties: dict[str, jax.Array],
    ):
        """
        Perform a single training step.

        Args:
        - electrons: The electrons in the systems.
        - atoms: The atoms in the systems.
        - config: The system configurations.
        - properties: The properties of the systems.
        Returns:
        - electrons: The electrons in the systems.
        - mol_data: Data per molecular structure.
        - aux_data: The auxiliary data to log.
        """
        (
            (electrons, self.params, self.opt_state, self.natgrad_state),
            mol_data,
            aux_data,
        ) = self.train_step(
            self.params,
            electrons,
            atoms,
            config,
            self.opt_state,
            self.next_key(),
            self.natgrad_state,
            properties,
        )
        self.last_grad_norm = np.array(aux_data['grad_norm']['final']).ravel()[0]
        self.iteration += 1
        return electrons, mol_data, aux_data

    def serialize_params(self) -> bytes:
        """
        Serialize the parameters of the network.
        """
        to_store = instance(self.params)
        return flax.serialization.msgpack_serialize(to_store)

    def load_params(self, blob: bytes):
        """
        Load the parameters of the network.

        Args:
        - blob: The serialized parameters.
        """
        data = flax.serialization.msgpack_restore(blob)
        self.params = replicate(data)
