import chex
import jax.numpy as jnp
from jaxopt import OSQP

from tabular_mvdrl.kernels import Kernel, kernel_matrix


def mmd2(
    k: Kernel,
    sup_p: chex.Array,
    sup_q: chex.Array,
    p: chex.Array,
    q: chex.Array,
    from_samples=False,
):
    """
    Compute the squared MMD between two weighted point sets.

    Evaluates ``p^T K_pp p + q^T K_qq q - 2 p^T K_pq q``, where ``K_xy`` is
    ``kernel_matrix(k, sup_x, sup_y)`` (assumed to be the Gram matrix
    ``[k(x_i, y_j)]`` — TODO confirm against ``tabular_mvdrl.kernels``).

    Args:
        k: kernel used to build the Gram matrices.
        sup_p: support points of the first distribution.
        sup_q: support points of the second distribution.
        p: weights on ``sup_p``; its leading dimension is the number of points.
        q: weights on ``sup_q``; its leading dimension is the number of points.
        from_samples: if True, the diagonal (self-similarity) terms of
            ``K_pp`` and ``K_qq`` are removed and the remaining terms are
            rescaled by ``n / (n - 1)`` (resp. ``m / (m - 1)``). With uniform
            weights ``1/n`` this gives the unbiased U-statistic estimate of
            the within-sample terms.

    Returns:
        The squared MMD value.

    Raises:
        ValueError: if ``from_samples`` is True and either side has fewer
            than two points (the ``n / (n - 1)`` correction is undefined).
    """
    Kpp = kernel_matrix(k, sup_p, sup_p)
    Kqq = kernel_matrix(k, sup_q, sup_q)
    Kpq = kernel_matrix(k, sup_p, sup_q)

    n = p.shape[0]
    m = q.shape[0]

    if from_samples:
        # Shapes are static, so this check is safe under jit.
        if n < 2 or m < 2:
            raise ValueError(
                "from_samples=True requires at least two points on each side"
            )
        # jnp.diag on a 2-D matrix returns its 1-D diagonal; wrapping it in a
        # second jnp.diag rebuilds a diagonal *matrix*. Subtracting that zeroes
        # exactly the diagonal, instead of broadcasting the vector across rows.
        Kpp = (Kpp - jnp.diag(jnp.diag(Kpp))) * n / (n - 1)
        Kqq = (Kqq - jnp.diag(jnp.diag(Kqq))) * m / (m - 1)

    return p.T @ Kpp @ p + q.T @ Kqq @ q - 2 * (p.T @ Kpq @ q)


def mmd_projection_pre(
    K: chex.Array,
    k: Kernel,
    fixed_support: chex.Array,
    p_support: chex.Array,
    p_probs: chex.Array,
    signed: bool = False,
    tol: float = 1e-3,
    maxiter: int = 4000,
):
    """
    Compute categorical MMD projection of the categorical distribution with
    support `p_support` and probability masses `p_probs` onto `fixed_support`.

    Solves the quadratic program

        minimize    0.5 * x^T K x - b^T x
        subject to  sum(x) == 1
                    x >= 0            (only when ``signed`` is False)

    with ``b = kernel_matrix(k, fixed_support, p_support) @ p_probs``. Up to a
    factor of 2 and an additive constant, this objective is the squared MMD
    between the projected and the original distributions.

    Args:
        K: precomputed kernel matrix over ``fixed_support``, expected to be
            ``kernel_matrix(k, fixed_support, fixed_support)``.
        k: kernel used to build the cross Gram matrix.
        fixed_support: target support points; its leading dimension ``n``
            is the number of output weights.
        p_support: support of the distribution being projected.
        p_probs: probability masses on ``p_support``.
        signed: if True, drop the nonnegativity constraint and allow
            signed weights (still summing to one).
        tol: OSQP convergence tolerance.
        maxiter: maximum number of OSQP iterations.

    Returns:
        The OSQP primal solution: a length-``n`` weight vector over
        ``fixed_support``. Constraints hold only up to solver tolerance.
    """
    b = kernel_matrix(k, fixed_support, p_support) @ p_probs

    n = fixed_support.shape[0]

    qp = OSQP(tol=tol, maxiter=maxiter)

    # A single row encodes the simplex constraint sum(x) == 1. Repeating the
    # same row n times adds no information but costs O(n^2) memory.
    params_eq = (jnp.ones((1, n)), jnp.ones(1))
    # G x <= h with G = -I, h = 0 encodes x >= 0. None means "no inequality".
    params_ineq = None if signed else (-jnp.eye(n), jnp.zeros(n))

    sol = qp.run(params_obj=(K, -b), params_eq=params_eq, params_ineq=params_ineq)

    return sol.params.primal


def mmd_projection(
    k: Kernel,
    fixed_support: chex.Array,
    p_support: chex.Array,
    p_probs: chex.Array,
    signed: bool = False,
    tol: float = 1e-3,
    maxiter: int = 4000,
) -> chex.Array:
    """
    Compute categorical MMD projection of the categorical distribution with
    support `p_support` and probability masses `p_probs` onto `fixed_support`.

    Convenience wrapper around ``mmd_projection_pre``. It builds the kernel
    matrix over ``fixed_support`` on every call, so callers projecting many
    distributions onto the same support may prefer to precompute it once and
    call ``mmd_projection_pre`` directly.

    Args:
        k: kernel defining the MMD.
        fixed_support: target support points.
        p_support: support of the distribution being projected.
        p_probs: probability masses on ``p_support``.
        signed: if True, allow signed (possibly negative) weights.
        tol: OSQP convergence tolerance.
        maxiter: maximum number of OSQP iterations.

    Returns:
        Weights over ``fixed_support`` as returned by ``mmd_projection_pre``.
    """
    K = kernel_matrix(k, fixed_support, fixed_support)
    # `signed` must be forwarded. Otherwise the callee's default (False)
    # silently overrides the caller's choice.
    return mmd_projection_pre(
        K,
        k,
        fixed_support,
        p_support,
        p_probs,
        signed=signed,
        tol=tol,
        maxiter=maxiter,
    )
