"""
Utilities for working with JAX.
Some of these functions are taken from
https://github.com/deepmind/ferminet/tree/jax/ferminet
"""

import functools
from typing import Callable, ParamSpec, TypeVar
from chex import ArrayTree, PRNGKey

import jax
import jax.tree_util as jtu
from jax import core


T = TypeVar('T')


def _identity(x):
    return x


def _leading_slice(x):
    return x[0]


def _split_in_two(key):
    return tuple(jax.random.split(key))


# pmapped identity: the leading axis of the input is mapped over local
# devices, so each device ends up holding one slice.
broadcast = jax.pmap(_identity)

# Take entry 0 along the leading axis of every leaf of a pytree.
instance = functools.partial(jtu.tree_map, _leading_slice)

# pmapped key split: every device splits its own key and returns the pieces
# as a tuple.
_p_split = jax.pmap(_split_in_two)


def p_split(key: PRNGKey) -> tuple[PRNGKey, ...]:
    """Split per-device PRNG keys with the pmapped ``_p_split``.

    Args:
        key: stacked keys whose leading axis is mapped over local devices.

    Returns:
        The tuple produced by ``tuple(jax.random.split(...))`` on each device
        (two new keys with the default ``jax.random.split``), each entry
        stacked along the leading device axis.
    """
    pair = _p_split(key)
    return pair


# Type variable bound to chex's ``ArrayTree``.
Tree = TypeVar('Tree', bound=ArrayTree)


def replicate(pytree: T) -> T:
    """Place one copy of ``pytree`` on every local device.

    Each leaf is stacked ``jax.local_device_count()`` times along a new
    leading axis, and the stacked tree is then passed through ``broadcast``
    (a pmapped identity), which maps that new axis over the local devices.

    Args:
        pytree: any pytree of array-like leaves.

    Returns:
        A pytree of the same structure whose leaves gained a leading device
        axis.
    """
    num_devices = jax.local_device_count()

    def _tile(leaf):
        return jax.lax.broadcast(leaf, (num_devices,))

    return broadcast(jtu.tree_map(_tile, pytree))


# Axis name we pmap over.
PMAP_AXIS_NAME = 'qmc_pmap_axis'

# Shortcut for jax.pmap over PMAP_AXIS_NAME. Prefer this if pmapping any
# function which does communications or reductions.
P = ParamSpec('P')
R = TypeVar('R')


def pmap(fn: Callable[P, R], *pmap_args, **pmap_kwargs) -> Callable[P, R]:
    """Shortcut for ``jax.pmap`` with ``axis_name`` fixed to PMAP_AXIS_NAME.

    Prefer this over a raw ``jax.pmap`` for any function that does
    communications or reductions, so that collectives bound to
    PMAP_AXIS_NAME can find their axis.

    Args:
        fn: function to map over local devices.
        *pmap_args: positional arguments forwarded to ``jax.pmap``.
        **pmap_kwargs: keyword arguments forwarded to ``jax.pmap``. Do not
            pass ``axis_name``; it is supplied here and a duplicate raises
            ``TypeError``.

    Returns:
        A function carrying ``fn``'s metadata (via ``functools.wraps``) that
        forwards its arguments to the pmapped ``fn``.
    """
    mapped_fn = jax.pmap(fn, *pmap_args, **pmap_kwargs, axis_name=PMAP_AXIS_NAME)

    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        return mapped_fn(*args, **kwargs)

    return functools.wraps(fn)(wrapped)


def _on_pmap_axis(collective):
    """Return ``collective`` with ``axis_name=PMAP_AXIS_NAME`` pre-bound.

    Because ``axis_name`` is bound by keyword, any parameter that comes after
    ``axis_name`` in the collective's signature must be passed by keyword to
    the returned partial; passing it positionally supplies ``axis_name``
    twice and raises ``TypeError``.
    """
    return functools.partial(collective, axis_name=PMAP_AXIS_NAME)


# jax.lax collectives specialised to the module's pmap axis.
pmean = _on_pmap_axis(jax.lax.pmean)
psum = _on_pmap_axis(jax.lax.psum)
pmax = _on_pmap_axis(jax.lax.pmax)
pmin = _on_pmap_axis(jax.lax.pmin)
pgather = _on_pmap_axis(jax.lax.all_gather)
pall_to_all = _on_pmap_axis(jax.lax.all_to_all)
pidx = _on_pmap_axis(jax.lax.axis_index)


def wrap_if_pmap(p_func: Callable[[T], T]) -> Callable[[T], T]:
    """Make ``p_func`` a no-op when PMAP_AXIS_NAME is not a bound axis.

    Args:
        p_func: single-argument function to apply when running under the
            module's pmap axis (e.g. ``pmean`` or ``psum``).

    Returns:
        A function that probes for PMAP_AXIS_NAME with ``core.axis_frame``.
        It returns ``p_func(obj)`` when the probe succeeds, and ``obj``
        unchanged when the probe raises ``NameError``.
    """

    @functools.wraps(p_func)
    def p_func_if_pmap(obj: T) -> T:
        try:
            # Expected to raise NameError when the axis is not bound.
            # NOTE(review): ``core.axis_frame`` is a low-level JAX API; confirm
            # it is still available in the pinned JAX version.
            core.axis_frame(PMAP_AXIS_NAME)
        except NameError:
            return obj
        else:
            # Called outside the ``try`` so a NameError raised by ``p_func``
            # itself propagates instead of being mistaken for "not under pmap".
            return p_func(obj)

    return p_func_if_pmap


# Variants of the collectives that return their argument unchanged when
# PMAP_AXIS_NAME is not a bound axis.
pmean_if_pmap = wrap_if_pmap(pmean)
psum_if_pmap = wrap_if_pmap(psum)
pmax_if_pmap = wrap_if_pmap(pmax)
pmin_if_pmap = wrap_if_pmap(pmin)
pgather_if_pmap = wrap_if_pmap(pgather)


def jit(fn: T, *jit_args, **jit_kwargs) -> T:
    """Type-preserving wrapper around ``jax.jit``.

    Compiles ``fn`` with ``jax.jit`` (forwarding ``jit_args``/``jit_kwargs``).
    It returns a plain function that carries ``fn``'s metadata via
    ``functools.wraps`` and calls the jitted version. The return annotation
    ``T`` lets static type checkers keep ``fn``'s signature.
    """
    compiled = jax.jit(fn, *jit_args, **jit_kwargs)  # type: ignore

    def wrapped(*args, **kwargs):
        return compiled(*args, **kwargs)

    return functools.wraps(fn)(wrapped)  # type: ignore
