# ruff: noqa: E402
# ruff: noqa: F841
import functools
import json
import logging
import math
import os

# Set before `jax` is imported below (which is why E402 is suppressed for this
# file). NOTE(review): presumably JAX reads this variable when it initializes
# its configuration, making 32-bit the default dtype width even though x64
# support is enabled further down -- confirm against the JAX version in use.
os.environ['JAX_DEFAULT_DTYPE_BITS'] = '32'

from collections import Counter, defaultdict

import flax
import jax
import jax.numpy as jnp
import numpy as np
from folx import batched_vmap
from seml.experiment import Experiment
from seml_logger import Logger, automain, add_default_observer_config

import globe.systems as Systems
import globe.systems.property as Properties
from globe.systems.dataset import Dataset
from globe.nn.globe import Globe
from globe.vmc.hamiltonian import make_local_energy_function

from globe.vmc.mcmc import make_mcmc
from globe.utils.jax_utils import pmap, broadcast, p_split

# Turn on JAX's x64 flag. Together with JAX_DEFAULT_DTYPE_BITS='32' set at the
# top of the file this presumably allows explicit 64-bit arrays while keeping
# 32-bit defaults -- TODO confirm.
jax.config.update('jax_enable_x64', True)
# Request float32 as the default matmul precision.
jax.config.update('jax_default_matmul_precision', 'float32')


# Experiment object that `config` and `main` below are registered on.
ex = Experiment()
add_default_observer_config(ex, notify_on_completed=True)


# Default configuration of the evaluation experiment. The local variables of
# this function are the configuration entries; most of them are received by
# `main` as parameters of the same name.
@ex.config
def config():
    chkpt_dir = None  # directory holding `config.json` and the `chk_<chkpt>.chk` files
    chkpt = 'final'  # checkpoint tag; the file loaded is `chk_<chkpt>.chk`
    systems = None  # systems to evaluate; None falls back to the checkpoint's `systems`

    total_samples = 1_000_000  # target number of energy samples per molecule
    batch_size = 8  # forwarded to `Dataset`; also divides `samples_per_batch`
    samples_per_batch = 4096  # total samples per batch, split across `batch_size`
    mcmc_steps = 50  # forwarded to `make_mcmc`
    therm_steps = 5000  # MCMC calls per batch before energies are recorded
    max_energy_batch_size = 64  # `max_batch_size` of the batched local-energy vmap

    run_name = None  # explicit run name; overrides the name built by `naming_fn`
    # NOTE(review): the two entries below are not read anywhere in this file;
    # presumably they are consumed by the seml_logger `automain` wrapper -- confirm.
    print_progress = True
    folder = '~/logs/pfaffian_debug'


def naming_fn(systems, chkpt_dir, chkpt, run_name):
    """Build the name of an evaluation run.

    Args:
        systems: Systems specification, or None to read the `systems` entry
            from `<chkpt_dir>/config.json`.
        chkpt_dir: Checkpoint directory; only read when `systems` is None.
        chkpt: Checkpoint tag used as the prefix of the generated name.
        run_name: Explicit name. Returned unchanged whenever it is not None.

    Returns:
        `run_name` if given, otherwise `<chkpt>_` followed by the
        `-`-joined `<molecule>_<count>` pairs of the evaluated molecules.
    """
    # An explicit name always wins; nothing else is touched in that case.
    if run_name is not None:
        return run_name

    if systems is None:
        config_path = os.path.join(chkpt_dir, 'config.json')
        with open(config_path, 'r') as handle:
            systems = json.load(handle)['systems']

    # Count how often each molecule occurs and encode that in the name.
    counts = Counter(Systems.get_molecules(systems))
    suffix = '-'.join(f'{mol}_{count}' for mol, count in counts.items())
    return f'{chkpt}_' + suffix


@automain(ex, naming_fn, default_folder='~/logs/pfaffian_neurips_eval')
def main(
    chkpt_dir: str,
    chkpt: str,
    systems: dict | None,
    total_samples: int,
    batch_size: int,
    samples_per_batch: int,
    mcmc_steps: int,
    therm_steps: int,
    max_energy_batch_size: int,
    seed: int,
    run_name: str | None,
    logger: Logger = None,
):
    """Sample a checkpointed Globe wave function and collect local energies.

    The checkpoint `chk_<chkpt>.chk` and `config.json` are read from
    `chkpt_dir`. Every batch of the dataset is first thermalized with
    `therm_steps` MCMC calls; afterwards MCMC calls and local-energy
    evaluations alternate until roughly `total_samples` energies per
    molecule have been recorded.

    Args:
        chkpt_dir: Directory containing `config.json` and the checkpoints.
        chkpt: Checkpoint tag; the file loaded is `chk_<chkpt>.chk`.
        systems: Systems to evaluate; None uses `config.json`'s `systems`.
        total_samples: Target number of energy samples per molecule.
        batch_size: Batch size handed to `Dataset`.
        samples_per_batch: Samples per batch, split across `batch_size` and
            rounded down to a multiple of the device count.
        mcmc_steps: Forwarded to `make_mcmc`.
        therm_steps: Number of MCMC calls per batch during thermalization.
        max_energy_batch_size: `max_batch_size` of the batched energy vmap.
        seed: Seed of the JAX PRNG key.
        run_name: Unused here; only relevant for `naming_fn`.
        logger: Logger providing `tqdm`; expected to be supplied by the
            `automain` wrapper.

    Returns:
        Dict mapping `repr(molecule)` to a 1-D NumPy array with all energies
        recorded for that molecule.

    Raises:
        ValueError: If the per-molecule sample count rounds down to zero.
    """
    n_devices = jax.device_count()
    key = jax.random.PRNGKey(seed)
    key, *subkeys = jax.random.split(key, n_devices + 1)
    pkey = broadcast(jnp.stack(subkeys))

    logging.info(f'Using devices: {jax.devices()}')

    config_file = os.path.join(chkpt_dir, 'config.json')
    with open(config_file, 'r') as inp:
        config = json.load(inp)

    if systems is None:
        systems = config['systems']

    chkpt = os.path.join(chkpt_dir, f'chk_{chkpt}.chk')
    with open(chkpt, 'rb') as inp:
        params = flax.serialization.msgpack_restore(inp.read())

    network = Globe(**config['globe'])
    wf = functools.partial(network.apply, method=network.wf)
    get_mol_params = functools.partial(
        network.apply, params, method=network.get_mol_params
    )
    # Argument 1 (the molecule configuration) is treated as static.
    get_mol_params = jax.jit(get_mol_params, static_argnums=1)

    # Local energy: bind the parameters, vmap over the electron batch in
    # chunks of at most `max_energy_batch_size`, then map over devices.
    energy = make_local_energy_function(
        wf, network.group_parameters, operator='forward'
    )
    energy = functools.partial(energy, params)
    energy = batched_vmap(
        energy, in_axes=(0, None, None, None), max_batch_size=max_energy_batch_size
    )
    energy = pmap(energy, in_axes=(0, None, None, None), static_broadcasted_argnums=2)

    # MCMC sampler on the same wave function; electrons and PRNG keys are
    # mapped over devices, everything else is broadcast.
    mcmc = make_mcmc(wf, mcmc_steps)
    mcmc = functools.partial(mcmc, params)
    mcmc = pmap(
        mcmc, in_axes=(0, None, None, None, 0, None), static_broadcasted_argnums=2
    )

    # Initialize dataset
    mols = Systems.get_molecules(systems)
    # We divide and multiply by n_devices to ensure that the batches can be
    # parallelized across multiple GPUs.
    key, subkey = jax.random.split(key)
    samples_per_mol = int(((samples_per_batch / batch_size) // n_devices) * n_devices)
    if samples_per_mol <= 0:
        # Fail before the expensive thermalization instead of hitting a
        # ZeroDivisionError when the number of steps is computed.
        raise ValueError(
            f'samples_per_batch={samples_per_batch} split across '
            f'batch_size={batch_size} yields no samples per molecule on '
            f'{n_devices} device(s).'
        )
    dataset = Dataset(
        subkey,
        mols,
        'random',
        'none',
        batch_size,
        samples_per_mol,
        (functools.partial(Properties.WidthScheduler, init_width=0.02),),
        True,
        'STO-6G',
        pretrain_localization='hf',
    )

    logging.info('Thermalizing')
    for batch in logger.tqdm(dataset):
        # `mol_config` is kept distinct from the checkpoint `config` above.
        electrons, atoms, mol_config, props = batch.to_jax()
        mol_params = get_mol_params(atoms, mol_config)
        for _ in range(therm_steps):
            # Re-read the batch so the states written by `update_states` in
            # the previous iteration are picked up.
            electrons, atoms, mol_config, props = batch.to_jax()
            pkey, psubkey = p_split(pkey)
            electrons, pmove = mcmc(
                electrons, atoms, mol_config, mol_params, psubkey, props['mcmc_width']
            )
            batch.update_states(electrons, pmove=pmove)

    logging.info('Energy computation')
    # Collect per-step chunks and concatenate once at the end; concatenating
    # onto a growing array every step copies all previous samples each time.
    chunks = defaultdict(list)
    n_steps = int(math.ceil(total_samples / samples_per_mol))
    for batch in logger.tqdm(dataset):
        electrons, atoms, mol_config, props = batch.to_jax()
        mol_params = get_mol_params(atoms, mol_config)
        mol_keys = [repr(m) for m in batch.molecules]
        for _ in range(n_steps):
            electrons, atoms, mol_config, props = batch.to_jax()
            pkey, psubkey = p_split(pkey)
            electrons, pmove = mcmc(
                electrons, atoms, mol_config, mol_params, psubkey, props['mcmc_width']
            )
            batch.update_states(electrons, pmove=pmove)
            energies = np.asarray(energy(electrons, atoms, mol_config, mol_params))
            # Move the last axis to the front and flatten the rest, giving one
            # row per molecule. NOTE(review): assumes the last axis indexes the
            # molecules of the batch -- confirm against the energy function.
            energies = np.transpose(energies, (2, 0, 1)).reshape(len(mol_keys), -1)
            for mol_key, mol_energies in zip(mol_keys, energies):
                chunks[mol_key].append(mol_energies)

    # The empty float64 prefix keeps the dtype promotion of the result the
    # same as when each chunk is concatenated onto `np.zeros((0,))`.
    return {
        mol_key: np.concatenate([np.zeros((0,)), *parts])
        for mol_key, parts in chunks.items()
    }
