import flax.linen as nn
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
import pyscf
from pyscf.scf.hf import SCF, RHF
from jax.scipy.special import gammaln
from jaxtyping import Array, ArrayLike, Float, Integer

from globe.nn.parameters import inverse_softplus


def get_cartesian_angulars(l: int) -> list[tuple[int, int, int]]:
    """Enumerate the Cartesian exponent triples of total angular momentum `l`.

    Args:
        l: Total angular momentum. Values below zero yield an empty list.

    Returns:
        All `(lx, ly, lz)` with `lx + ly + lz == l`, ordered with `lx`
        descending and, for equal `lx`, `ly` descending. Every entry is a
        plain Python `int`.
    """
    # Callers may hand in NumPy integer scalars; coercing once keeps the
    # returned tuples made of built-in ints, as the annotation promises.
    l = int(l)
    return [
        (lx, ly, l - lx - ly)
        for lx in range(l, -1, -1)
        for ly in range(l - lx, -1, -1)
    ]


def factorial2(n: Integer[ArrayLike, '...']) -> Float[Array, '...']:
    """Evaluate the double factorial `n!!` element-wise through the gamma function.

    The result is `scale * Gamma(n / 2 + 1)`, where `scale` is `2 ** (n / 2)`
    for even `n` and `2 ** (n / 2 + 0.5) / sqrt(pi)` for odd `n`. The output
    is a floating-point value, not an exact integer.

    Args:
        n: Integer or array of integers.

    Returns:
        Floating-point array with the same shape as `n`.
    """
    n = jnp.array(n)
    half = n / 2
    is_odd = n % 2

    odd_scale = jnp.power(2, half + 0.5) / jnp.sqrt(jnp.pi)
    even_scale = jnp.power(2, half)
    scale = jnp.where(is_odd, odd_scale, even_scale)

    # The log-gamma value is cast to float32 before exponentiation.
    gamma_term = jnp.exp(gammaln(half + 1).astype(jnp.float32))
    return scale * gamma_term


def safe_power(x, power, eps=1e-8):
    """Raise `x` to `power` after nudging near-zero entries away from zero.

    `0.0 ** 0.0` has undefined gradients, and `jnp.power` has further issues
    for every power when the base is exactly zero. Entries with
    `|x| < eps` are therefore replaced by `x + eps` before exponentiation.

    Args:
        x: Base values.
        power: Exponent(s), broadcast against `x`.
        eps: Threshold below which a base counts as zero, and the shift
            added to such a base.

    Returns:
        `jnp.power` of the shifted bases and `power`.
    """
    near_zero = jnp.abs(x) < eps
    shifted = jnp.where(near_zero, x + eps, x)
    return jnp.power(shifted, power)


class GTOShell(nn.Module):
    """One contracted Cartesian Gaussian shell attached to a single atom.

    The contraction coefficients and the exponents are registered as module
    parameters. Exponents are stored through `inverse_softplus` and mapped
    back with softplus on every call.
    """

    # Index selecting this shell's atom on the second-to-last axis of `diffs`.
    atom_idx: int
    # Angular momentum; the Cartesian exponents (lx, ly, lz) sum to this.
    l: int
    # Initial contraction coefficients.
    coeffs: ArrayLike
    # Initial exponents; their absolute value is used for initialisation.
    zetas: ArrayLike

    @nn.compact
    def __call__(self, diffs):
        """Evaluate every Cartesian component of the shell.

        Args:
            diffs: Array indexed by atom on its second-to-last axis. The last
                axis holds three displacement components followed by the
                distance (the layout `GTOBasis` builds).

        Returns:
            Array with one trailing entry per Cartesian angular component.
        """
        dtype = diffs.dtype

        def init_coeffs(*_):
            return jnp.array(self.coeffs, dtype=dtype)

        def init_zetas(*_):
            return inverse_softplus(jnp.abs(self.zetas).astype(dtype))

        contraction = self.param('coeffs', init_fn=init_coeffs)
        exponents = jnn.softplus(self.param('zetas', init_fn=init_zetas))

        powers = jnp.array(get_cartesian_angulars(self.l), dtype=jnp.int32)

        # Keep only the entries that belong to this shell's atom.
        local = diffs[..., self.atom_idx, :]
        xyz = local[..., :3]
        r = local[..., 3:]

        # Per-primitive radial normalisation and per-component angular one.
        radial_norm = (
            (2 * exponents / np.pi) ** (3 / 4) * (4 * exponents) ** (self.l / 2)
        ).astype(xyz.dtype)
        angular_norm = 1.0 / jnp.sqrt(factorial2(2 * powers - 1).prod(axis=-1))

        # x^lx * y^ly * z^lz for every exponent triple.
        monomials = safe_power(xyz[..., None, :], powers).prod(-1)
        primitives = radial_norm * jnp.exp(-jnp.abs(exponents * r**2))
        # Contract the primitives into a single radial value.
        contracted = (contraction * primitives).sum(-1, keepdims=True)
        return angular_norm * monomials * contracted


class GTOBasis(nn.Module):
    """Evaluate all basis shells of a PySCF molecule at electron positions."""

    # Molecule providing the shell definitions and the atom coordinates.
    mol: pyscf.gto.Mole

    @nn.compact
    def __call__(self, electrons):
        """Return the concatenated outputs of every shell.

        Args:
            electrons: Electron positions; the first axis indexes electrons.

        Returns:
            The per-shell outputs concatenated along the last axis.
        """
        mol = self.mol

        # One submodule per basis shell, built before the `center` parameter.
        shells = [
            GTOShell(
                atom_idx=mol.bas_atom(idx),
                l=mol.bas_angular(idx),
                coeffs=mol.bas_ctr_coeff(idx).reshape(-1),
                zetas=mol.bas_exp(idx),
            )
            for idx in range(mol.nbas)
        ]

        def init_center(*_):
            return jnp.array(mol.atom_coords(), dtype=electrons.dtype)

        nuclei = self.param('center', init_fn=init_center)

        # Electron-minus-centre displacements, with their norm appended as an
        # extra trailing component.
        offsets = electrons[:, None] - nuclei
        lengths = jnp.linalg.norm(offsets, axis=-1, keepdims=True)
        features = jnp.concatenate([offsets, lengths], axis=-1)

        return jnp.concatenate([shell(features) for shell in shells], -1)


class HF(nn.Module):
    """Hartree-Fock orbitals built from a GTO basis and learnable coefficients."""

    # PySCF mean-field object; its `mol` supplies the basis and electron counts.
    mean_field: SCF
    # Initial orbital coefficients, iterated over per spin along axis 0.
    mo_coeff: ArrayLike

    @nn.compact
    def __call__(self, electrons):
        """Evaluate the orbitals separately for each spin block.

        The basis values are split along axis 0 at `mol.nelec[0]`. For an
        `RHF` mean field the coefficient array is repeated twice along
        axis 0 so each spin block is paired with coefficients.

        Args:
            electrons: Electron positions, first block first along axis 0.

        Returns:
            Tuple with one orbital array per spin block.
        """

        def init_mo_coeff(*_):
            return jnp.array(self.mo_coeff, dtype=electrons.dtype)

        weights = self.param('mo_coeff', init_fn=init_mo_coeff)

        mol = self.mean_field.mol
        basis_values = GTOBasis(mol)(electrons)
        n_up = int(mol.nelec[0])
        per_spin_ao = jnp.split(basis_values, np.array([n_up]), axis=0)

        if isinstance(self.mean_field, RHF):
            weights = jnp.repeat(weights, 2, axis=0)

        orbitals = []
        for spin_ao, spin_weights in zip(per_spin_ao, weights):
            # Keep as many orbital columns as the block's second-to-last axis
            # is long.
            selected = spin_weights[..., : spin_ao.shape[-2]]
            orbitals.append(jnp.einsum('...ea,...am->...em', spin_ao, selected))
        return tuple(orbitals)
