import pickle
from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar

import jax
import jax.numpy as jnp
import numpy as np

R = TypeVar('R')
P = ParamSpec('P')


@jax.jit
def test_nan(tree):
    return jax.tree_util.tree_reduce(
        jnp.logical_or, jax.tree_util.tree_map(lambda x: jnp.isnan(x).any(), tree)
    )


def try_to_numpy(array):
    """Convert a ``jax.Array`` to a host ``np.ndarray``; return other values unchanged.

    If converting the JAX array raises ``RuntimeError`` (which is what JAX
    raises when the array's buffer has already been deleted), the placeholder
    string ``'deleted'`` is returned instead of propagating the error.
    """
    if not isinstance(array, jax.Array):
        # Non-JAX leaves (Python scalars, NumPy arrays, arbitrary objects)
        # pass through untouched.
        return array
    try:
        return np.asarray(array)
    except RuntimeError:
        return 'deleted'


def dump_tree(tree, file):
    """Pickle *tree* into *file* after mapping every leaf through ``try_to_numpy``.

    ``jax.Array`` leaves are stored as NumPy arrays (or the string
    ``'deleted'`` if their buffer is gone). Other leaves are pickled as-is.

    Args:
        tree: Pytree to serialize.
        file: A writable binary file object, as required by ``pickle.dump``.
    """
    pickle.dump(jax.tree_util.tree_map(try_to_numpy, tree), file)


def add_nan_catch(
    fun: Callable[P, R],
    extras: Any = None,
    *,
    dump_path: str = 'dump.pickle',
) -> Callable[P, R]:
    """Wrap *fun* so that a NaN anywhere in its result aborts with a diagnostic dump.

    After each call, every leaf of the returned pytree is checked for NaNs
    with ``test_nan``. If a NaN is found, the positional args, keyword args,
    result and *extras* are written with ``dump_tree`` to *dump_path*, which
    is opened in ``'wb'`` mode and therefore overwritten. A ``ValueError`` is
    then raised. Otherwise the result is returned unchanged.

    Args:
        fun: The function to guard.
        extras: Arbitrary additional context stored under the ``'extras'``
            key of the dump.
        dump_path: Where the dump is written. The default is
            ``'dump.pickle'``, resolved relative to the current working
            directory.

    Returns:
        A wrapper with the same signature and return type as *fun*.

    Raises:
        ValueError: If the result contains a NaN. If writing the dump itself
            fails (e.g. an argument cannot be pickled or the path is not
            writable), the ValueError is still raised, with the dump error
            chained as its ``__cause__``.
    """

    @wraps(fun)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        result = fun(*args, **kwargs)
        # A result with no leaves (e.g. None) cannot contain a NaN; skipping
        # the check also avoids reducing over an empty pytree.
        if not jax.tree_util.tree_leaves(result) or not test_nan(result).item():
            return result
        payload = {
            'args': args,
            'kwargs': kwargs,
            'result': result,
            'extras': extras,
        }
        try:
            with open(dump_path, 'wb') as f:
                dump_tree(payload, f)
        except Exception as err:
            # Pickling/IO can fail in many ways; never let that mask the NaN.
            raise ValueError(
                f"NaN encountered, but dumping input and output to '{dump_path}' failed."
            ) from err
        raise ValueError(
            f"NaN encountered, input and output has been dumped to '{dump_path}'."
        )

    return wrapped
