import dataclasses
from typing import Callable

import chex
import einops
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers import linear as linear_solver
from typing_extensions import Self

from tabular_mvdrl.kernels import Kernel
from tabular_mvdrl.mmd import mmd2


@dataclasses.dataclass(frozen=True, kw_only=True)
class DiscreteDistribution:
    """
    A finitely-supported distribution: atom locations and their probabilities.

    The `_tree_flatten` / `_tree_unflatten` hooks expose `(locs, probs)` as
    pytree children, with no static auxiliary data.
    """

    # Atom locations; the leading axis indexes atoms (see `empirical_from`).
    locs: chex.Array
    # Probability mass attached to each atom.
    probs: chex.Array

    @staticmethod
    def empirical_from(locs: chex.Array) -> "DiscreteDistribution":
        """
        Build the uniform (empirical) distribution over the given atoms.

        Args:
            locs: Atom locations; `locs.shape[0]` is taken as the atom count.

        Returns:
            A distribution on `locs` in which every atom has probability
            `1 / locs.shape[0]`.
        """
        num_atoms = locs.shape[0]
        return DiscreteDistribution(locs=locs, probs=jnp.ones(num_atoms) / num_atoms)

    def pushforward_linear(self, reward: chex.Array) -> Self:
        """
        Compute the pushforward of this distribution through a linear function.
        This function is useful for applying a batch of pushforwards,
        e.g. by mapping over distributions and/or reward functions.

        Args:
            reward: Array whose leading axis (size `r`) is contracted against
                the trailing axis of `locs`; any further axes are kept.

        Returns:
            A distribution with the same `probs` whose atoms are the
            contraction of each row of `locs` (shape `(n, r)`) with `reward`,
            giving locations of shape `(n, ...)`.
        """
        locs = einops.einsum(reward, self.locs, "r ..., n r -> n ...")
        return DiscreteDistribution(locs=locs, probs=self.probs)

    def _tree_flatten(self):
        """Return `(children, aux_data)` for JAX pytree registration."""
        # JAX documents aux_data as hashable; there is no static data to
        # carry, so use `None` rather than an (unhashable) empty list.
        return ((self.locs, self.probs), None)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the `(locs, probs)` children."""
        del aux_data  # Nothing static is stored in the treedef.
        return cls(locs=children[0], probs=children[1])


# Register DiscreteDistribution as a JAX pytree node, using the class's own
# flatten/unflatten hooks, so instances can be passed through JAX
# transformations (e.g. the `jax.vmap` call in `SupremalMetric`).
jax.tree_util.register_pytree_node(
    DiscreteDistribution,
    DiscreteDistribution._tree_flatten,
    DiscreteDistribution._tree_unflatten,
)

# Type of a discrepancy between two distributions: a callable taking a pair of
# DiscreteDistribution objects and returning a value annotated as `float`.
ProbabilityMetric = Callable[[DiscreteDistribution, DiscreteDistribution], float]


@dataclasses.dataclass(frozen=True)
class SupremalMetric:
    """
    Largest value of a base metric over a batch of distribution pairs.

    `base_metric` is vectorised with `jax.vmap` (default axes, i.e. the
    leading axis of every array leaf of `d1` and `d2`), and the maximum of
    the resulting values is returned.
    """

    # Metric evaluated on each pair of distributions in the batch.
    base_metric: ProbabilityMetric

    def __call__(self, d1: DiscreteDistribution, d2: DiscreteDistribution):
        """Return the maximum of `base_metric` over the batched pairs."""
        batched_metric = jax.vmap(self.base_metric)
        return jnp.max(batched_metric(d1, d2))


@dataclasses.dataclass(frozen=True)
class SquaredMMDMetric:
    """
    Metric that delegates to `mmd2` with a fixed kernel.

    The atoms and probabilities of both distributions are forwarded to
    `mmd2` together with `kernel`.
    """

    # Kernel handed to `mmd2` on every call.
    kernel: Kernel

    def __call__(self, d1: DiscreteDistribution, d2: DiscreteDistribution):
        """Return `mmd2` evaluated on the two weighted atom sets."""
        atoms_1, weights_1 = d1.locs, d1.probs
        atoms_2, weights_2 = d2.locs, d2.probs
        return mmd2(self.kernel, atoms_1, atoms_2, weights_1, weights_2)


@dataclasses.dataclass(frozen=True, kw_only=True)
class Wasserstein2Metric:
    """
    Optimal-transport cost between two distributions via OTT's Sinkhorn solver.

    A point-cloud geometry is built from the two sets of atoms, a linear OT
    problem is posed with the distributions' probabilities as marginals, and
    the solver output's `reg_ot_cost` is returned.

    NOTE(review): the ground cost is whatever `PointCloud` uses by default
    (presumably squared Euclidean) and the returned value is the regularised
    cost reported by the solver -- confirm against the OTT documentation.
    """

    # Passed to `PointCloud` as its `epsilon` argument.
    epsilon: float = 0.001
    # Passed to `Sinkhorn` as its `threshold` argument.
    threshold: float = 0.001
    # Passed to `Sinkhorn` as its `max_iterations` argument.
    max_iterations: int = 1000

    def __call__(self, d1: DiscreteDistribution, d2: DiscreteDistribution):
        """Solve the OT problem between `d1` and `d2`; return `reg_ot_cost`."""

        def _as_point_matrix(atoms: chex.Array) -> chex.Array:
            # 1-D atom arrays get a trailing axis of size one; anything else
            # is passed through untouched.
            return jnp.expand_dims(atoms, -1) if atoms.ndim == 1 else atoms

        geometry = pointcloud.PointCloud(
            _as_point_matrix(d1.locs),
            _as_point_matrix(d2.locs),
            epsilon=self.epsilon,
        )
        problem = linear_problem.LinearProblem(geometry, a=d1.probs, b=d2.probs)
        sinkhorn = linear_solver.sinkhorn.Sinkhorn(
            threshold=self.threshold,
            max_iterations=self.max_iterations,
            norm_error=2,
        )
        return sinkhorn(problem).reg_ot_cost
