# -*- coding: utf-8 -*-
"""
Created on Mon Apr 29 21:44:15 2024

@author: javie
"""
# Group and OT Utilities
import numpy as np
import jax.numpy as jnp
from jax import vmap
import objax
import ot

def equivariance_tuple(x, modelo, G, repin, repout):
    """
    Evaluate both sides of the equivariance condition on a batch.

    One group element g_i is drawn per row x_i of ``x``. The function then
    returns the pair (modelo(g_i . x_i), g_i . modelo(x_i)), stacked over i.
    For an exactly equivariant model the two arrays coincide.

    Parameters
    ----------
    x : array
        Batch of inputs. The leading axis indexes samples, and one group
        element is drawn per row.
    modelo : objax.nn.Module (or Callable)
        Model under test. It is called as ``modelo(inputs, training=False)``.
    G : emlp.group
        Group whose ``samples(n)`` method supplies the n group elements.
    repin : emlp.representation
        Input representation. ``rho_dense(g)`` must give the matrix that acts
        on each input vector.
    repout : emlp.representation
        Output representation. ``rho_dense(g)`` must give the matrix that acts
        on each output vector.

    Returns
    -------
    tuple
        ``(modelo(g.x), g.modelo(x))``, each with the batch axis first.
    """
    batch = jnp.array(x)
    group_elements = G.samples(batch.shape[0])

    # Dense representation matrices for every sampled element, batched via vmap.
    in_matrices = vmap(repin.rho_dense)(group_elements)
    out_matrices = vmap(repout.rho_dense)(group_elements)

    # Treat each vector as a column so the batched matmul applies one matrix
    # per sample, then drop the trailing singleton axis again.
    acted_inputs = jnp.squeeze(in_matrices @ jnp.expand_dims(batch, -1), axis=-1)
    model_of_acted = modelo(acted_inputs, training=False)

    model_outputs = modelo(batch, training=False)
    acted_outputs = jnp.squeeze(
        out_matrices @ jnp.expand_dims(model_outputs, -1), axis=-1
    )

    return model_of_acted, acted_outputs

def rel_err(a, b):
    """
    Relative discrepancy between two arrays.

    Computes ``RMS(a - b) / (RMS(a) + RMS(b))`` with ``RMS(v) = sqrt(mean(v**2))``.
    For equally shaped inputs the triangle inequality keeps the value in [0, 1].
    If both inputs are identically zero they agree exactly, and 0 is returned
    instead of the NaN that 0/0 would produce.

    Parameters
    ----------
    a, b : array
        Arrays to compare (broadcastable against each other).

    Returns
    -------
    jax scalar
        The relative error.
    """
    diff_rms = jnp.sqrt(((a - b) ** 2).mean())
    scale = jnp.sqrt((a ** 2).mean()) + jnp.sqrt((b ** 2).mean())
    # scale == 0 implies a == b == 0, hence diff_rms == 0. Dividing by 1 then
    # yields the correct value 0 without a NaN.
    return diff_rms / jnp.where(scale > 0, scale, 1.0)

def equivariance_err(x, modelo, G, repin, repout):
    """
    Relative equivariance error of ``modelo`` on the batch ``x``.

    Draws one group element per sample (via ``equivariance_tuple``) and
    compares ``modelo(g.x)`` against ``g.modelo(x)`` with ``rel_err``.

    Parameters
    ----------
    x : array
        Batch of inputs, samples along the leading axis.
    modelo : objax.nn.Module (or Callable)
        Model under test, called as ``modelo(inputs, training=False)``.
    G : emlp.group
        Group providing ``samples(n)``.
    repin : emlp.representation
        Representation acting on the input space.
    repout : emlp.representation
        Representation acting on the output space.

    Returns
    -------
    scalar
        Value of ``rel_err`` between the two sides; smaller means the model
        is closer to equivariant.
    """
    transformed_then_mapped, mapped_then_transformed = equivariance_tuple(
        x, modelo, G, repin, repout
    )
    return rel_err(transformed_then_mapped, mapped_then_transformed)


def Wasserstein_Distance(m1, m2, w1=None, w2=None, p=2, root=False):
    """
    Empirical p-Wasserstein distance between two weighted point clouds.

    The ground cost is ``||x - y||_2 ** p``. The optimal transport problem is
    solved exactly with ``ot.emd2``.

    Parameters
    ----------
    m1 : array
        First "empirical measure", array of shape (n_samples_1, dimension).
    m2 : array
        Second "empirical measure", array of shape (n_samples_2, dimension).
    w1 : array, optional
        Weights of the points of ``m1``. ``None`` or an empty sequence means
        uniform weights. The default is None.
    w2 : array, optional
        Weights of the points of ``m2``. Same convention as ``w1``.
    p : int or float, optional
        Exponent of the Wasserstein distance. The default is 2.
    root : bool, optional
        If True return W_p, otherwise return W_p ** p. The default is False.

    Returns
    -------
    float
        W_p(m1, m2) ** p, or W_p(m1, m2) when ``root`` is True.
    """
    x1 = np.asarray(m1)
    x2 = np.asarray(m2)
    if p == 2:
        distance_matrix = ot.dist(x1, x2, metric='sqeuclidean')
    else:
        # ot.dist ignores `p` for metric='sqeuclidean'. Previously this made
        # every p use the squared cost, so build ||x - y||**p explicitly.
        distance_matrix = ot.dist(x1, x2, metric='euclidean') ** p

    # POT reads an empty weight vector as "uniform"; None maps to the same.
    # Explicit weights are converted to NumPy so they share M's backend.
    a = [] if w1 is None or len(w1) == 0 else np.asarray(w1, dtype=np.float64)
    b = [] if w2 is None or len(w2) == 0 else np.asarray(w2, dtype=np.float64)

    W_loss = ot.emd2(a=a, b=b, M=distance_matrix)
    return np.power(W_loss, 1 / p) if root else W_loss

def rel_measure_distance(m1, m2, p=2, root=False, mode="trace", return_Wasserstein=False):
    """
    Scale-free Wasserstein discrepancy between two empirical measures.

    The uniformly weighted Wasserstein term from ``Wasserstein_Distance`` is
    divided by ``sigma1**p + sigma2**p``, or by its p-th root when ``root``
    is True. Here ``sigma_i`` is an average standard deviation of ``m_i``,
    derived from its sample covariance ``C_i`` in dimension ``d``:

    * mode="trace": sigma_i = sqrt(trace(C_i) / d)
    * mode="det":   sigma_i = sqrt(det(C_i) ** (1 / d))

    Numerator and denominator carry the same units (length**p, or length
    when ``root`` is True). The ratio is therefore invariant to a common
    rescaling of both measures.

    Parameters
    ----------
    m1 : array
        Input "empirical measure", array of shape (n_samples, dimension).
    m2 : array
        Input "empirical measure", array of shape (n_samples, dimension).
    p : int, optional
        Exponent for the Wasserstein distance calculation. The default is 2.
    root : bool, optional
        Whether the p-th root is taken of both the Wasserstein term and the
        normalizer. The default is False.
    mode : str, optional
        One of ["trace", "det"]. Selects how the per-dimension variance is
        obtained from the covariance matrix. The default is "trace".
    return_Wasserstein : bool, optional
        If True, also return the raw Wasserstein term. The default is False.

    Returns
    -------
    float or tuple
        The relative distance. When ``return_Wasserstein`` is True, returns
        ``(relative_distance, W)`` instead, where W is the value returned by
        ``Wasserstein_Distance`` (W_p**p, or W_p if ``root``).

    Raises
    ------
    ValueError
        If ``mode`` is not "trace" or "det".
    """
    variance_of = {
        "trace": lambda c: jnp.trace(c) / c.shape[-1],
        "det": lambda c: jnp.power(jnp.linalg.det(c), 1 / c.shape[-1]),
    }
    # Validate before the (potentially expensive) optimal transport solve.
    if mode not in variance_of:
        raise ValueError(f"mode must be one of {sorted(variance_of)}, got {mode!r}")

    W_dist = Wasserstein_Distance(m1, m2, p=p, root=root)

    # jnp.cov may return a 0-d array for single-feature data; make it 2-D so
    # that trace/det/shape[-1] work uniformly.
    cov_matrix1 = jnp.atleast_2d(jnp.cov(m1, rowvar=False))
    cov_matrix2 = jnp.atleast_2d(jnp.cov(m2, rowvar=False))

    # The helpers yield a variance. Take the square root to get a standard
    # deviation, so that std**p has the same units as W_p**p.
    std1 = jnp.sqrt(variance_of[mode](cov_matrix1))
    std2 = jnp.sqrt(variance_of[mode](cov_matrix2))

    total_variation = std1 ** p + std2 ** p
    if root:
        total_variation = jnp.power(total_variation, 1 / p)

    relative = (W_dist / total_variation).item()
    return (relative, W_dist) if return_Wasserstein else relative


def inner_approx(f1, f2, samples=10000, stddev=1, input_dim=2, seed=0):
    """
    Monte Carlo estimate of E_x[<f1(x), f2(x)>].

    The points x are drawn with ``objax.random.normal`` using the given
    ``stddev`` and a fresh generator seeded with ``seed``. The same seed
    therefore reproduces the same sample.

    Parameters
    ----------
    f1, f2 : Callable
        Functions evaluated on the (samples, input_dim) batch. Their outputs
        are assumed to be 2-D, (samples, out_dim) — the inner product sums
        over axis 1.
    samples : int, optional
        Number of Monte Carlo points. The default is 10000.
    stddev : float, optional
        Standard deviation passed to ``objax.random.normal``. The default is 1.
    input_dim : int, optional
        Dimension of each sampled point. The default is 2.
    seed : int, optional
        Seed of the objax generator. The default is 0.

    Returns
    -------
    float
        Sample mean of the pointwise inner products.
    """
    rng = objax.random.Generator(seed=seed)
    points = objax.random.normal((samples, input_dim), stddev=stddev, generator=rng)
    pointwise = f1(points) * f2(points)
    per_sample_dot = pointwise.sum(axis=1)
    return per_sample_dot.mean().item()

def norm_approx(f, samples=10000, stddev=1, input_dim=2, seed=0):
    """
    Monte Carlo estimate of E_x[||f(x)||^2].

    The points x are drawn with ``objax.random.normal`` using the given
    ``stddev`` and a fresh generator seeded with ``seed``.

    Parameters
    ----------
    f : Callable
        Function evaluated on the (samples, input_dim) batch. Its output is
        assumed to be 2-D, (samples, out_dim) — squares are summed over axis 1.
    samples : int, optional
        Number of Monte Carlo points. The default is 10000.
    stddev : float, optional
        Standard deviation passed to ``objax.random.normal``. The default is 1.
    input_dim : int, optional
        Dimension of each sampled point. The default is 2.
    seed : int, optional
        Seed of the objax generator. The default is 0.

    Returns
    -------
    float
        Sample mean of the squared output norms.
    """
    rng = objax.random.Generator(seed=seed)
    points = objax.random.normal((samples, input_dim), stddev=stddev, generator=rng)
    squared_norms = (f(points) ** 2).sum(axis=1)
    return squared_norms.mean().item()

def generalization_err(model_t, model_s, samples=10000, stddev=1, input_dim=2, seed=0):
    """
    Monte Carlo estimate of E_x[||model_s(x) - model_t(x)||^2].

    The points x are drawn with ``objax.random.normal`` using the given
    ``stddev`` and a fresh generator seeded with ``seed``. Unlike
    ``inner_approx``/``norm_approx``, the result is not converted with
    ``.item()``.

    Parameters
    ----------
    model_t : Callable
        Teacher (target) function.
    model_s : Callable
        Student function. Both outputs are assumed to be (samples, out_dim).
    samples : int, optional
        Number of Monte Carlo points. The default is 10000.
    stddev : float, optional
        Standard deviation passed to ``objax.random.normal``. The default is 1.
    input_dim : int, optional
        Dimension of each sampled point. The default is 2.
    seed : int, optional
        Seed of the objax generator. The default is 0.

    Returns
    -------
    array scalar
        Sample mean of the squared output differences.
    """
    rng = objax.random.Generator(seed=seed)
    points = objax.random.normal((samples, input_dim), stddev=stddev, generator=rng)
    teacher_out = model_t(points)
    student_out = model_s(points)
    residual = student_out - teacher_out
    return (residual ** 2).sum(axis=1).mean()

def symmetrization_gap(teacher_network, model, model_G, model_perp, approx_samples,
                       stddev, kernel=None, *, input_dim=2, seed=0):
    """
    Monte Carlo estimates of the quantities in a symmetrization-gap relation.

    All Monte Carlo terms use the same ``seed``, hence the same Gaussian
    sample, so their estimation noise is correlated.

    Parameters
    ----------
    teacher_network : Callable
        Target function. Must also provide ``get_particles()`` when ``kernel``
        is given.
    model : Callable
        Student model. Must also provide ``get_particles()`` when ``kernel``
        is given.
    model_G : Callable
        Presumably the symmetrized version of ``model`` — confirm with caller.
    model_perp : Callable
        Presumably the component of ``model`` orthogonal to the symmetric
        subspace — confirm with caller.
    approx_samples : int
        Number of Monte Carlo points per estimate.
    stddev : float
        Standard deviation of the Gaussian sampling distribution.
    kernel : tuple, optional
        ``(kernel_fn, vG_generator)``. ``kernel_fn(P, Q)`` must return an
        array (its ``.mean()`` is used), and ``vG_generator`` maps particle
        positions to particle positions. If None, RHS2 is not computed.
    input_dim : int, optional (keyword-only)
        Input dimension of the networks, forwarded to the Monte Carlo
        helpers. The default is 2, the previously hard-coded value.
    seed : int, optional (keyword-only)
        Seed of the Monte Carlo sample. The default is 0, the previously
        hard-coded value.

    Returns
    -------
    tuple
        ``(LHS, RHS1, RHS2)`` where

        - LHS  = gen_err(teacher, model) - gen_err(teacher, model_G)
        - RHS1 = E||model_perp||^2 - 2 E<teacher, model_perp>
        - RHS2 = 0.5 * mean(K(P, P) - K(P, vG(P))) + mean(K(T, vG(P)) - K(T, P))

        Here P are the particles of ``model`` and T those of
        ``teacher_network``. RHS2 is None when ``kernel`` is None.
    """
    mc_kwargs = dict(samples=approx_samples, stddev=stddev, input_dim=input_dim, seed=seed)

    inner = inner_approx(teacher_network, model_perp, **mc_kwargs)
    print("Correlation between teacher and 'perpendicular' model: ", inner)
    LHS = (generalization_err(teacher_network, model, **mc_kwargs)
           - generalization_err(teacher_network, model_G, **mc_kwargs))
    RHS1 = norm_approx(model_perp, **mc_kwargs) - 2 * inner

    if kernel is None:
        return LHS, RHS1, None

    kernel_fn, vG_generator = kernel
    particle_positions_model = model.get_particles()
    particle_positions_teacher = teacher_network.get_particles()
    # vG_generator is invoked separately for K2 and K3, as in the original,
    # in case it is not deterministic.
    K1 = kernel_fn(particle_positions_model, particle_positions_model)
    K2 = kernel_fn(particle_positions_model, vG_generator(particle_positions_model))
    K3 = kernel_fn(particle_positions_teacher, vG_generator(particle_positions_model))
    K4 = kernel_fn(particle_positions_teacher, particle_positions_model)

    RHS2 = (1 / 2) * (K1 - K2).mean() + (K3 - K4).mean()
    return LHS, RHS1, RHS2