import sys
import os
from functools import reduce
import argparse

# Argument parser setup
def _str2bool(value):
    """Convert a command-line token to a bool.

    argparse's ``type=bool`` calls ``bool(token)``, which is True for every
    non-empty string (including "False"). This converter interprets the text.

    Args:
        value: The raw command-line string (or an actual bool).

    Returns:
        True for "true"/"t"/"yes"/"y"/"1", False for "false"/"f"/"no"/"n"/"0"
        and the empty string (case-insensitive, surrounding spaces ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description='Run Main Experiments')

# Adding arguments for hyperparameters with their default values
parser.add_argument('--path_save', type=str, default="", help='Path where the results will be saved')
parser.add_argument('--batch_size', type=int, default=20, help='Batch size for training')
parser.add_argument('--lr', type=float, default=50, help='Learning rate')
parser.add_argument('--t_epochs', type=int, default=20, help='Number of training epochs')
parser.add_argument('--tau', type=float, default=1e-4, help='Norm regularization parameter')
parser.add_argument('--beta', type=float, default=1e-6, help='Noise parameter')
parser.add_argument('--data_std', type=float, default=4, help='Standard deviation of the data')
parser.add_argument('--granular', type=int, default=5, help='Granularity for saving particles')
# Boolean flags take an explicit value, e.g. ``--equiv_init True``.
parser.add_argument('--equiv_init', type=_str2bool, default=False, help='Boolean flag for equivariant initialization of the student')
parser.add_argument('--teacher_mode', type=str, default='free', help='Teacher mode (e.g. "free", "weak", "strong")')
parser.add_argument('--n_reps', type=int, default=10, help='Number of times to repeat the experiments')
parser.add_argument('--n_p', type=int, default=5, help='Number of particles of the student network')
parser.add_argument('--n_t', type=int, default=5, help='Number of particles for the teacher_network')
parser.add_argument('--fix_teacher', type=_str2bool, default=True, help='Boolean flag for Fixing the Teacher Particles')

# Parse the command line once; every constant below is read from `args`.
args = parser.parse_args()

# Optimisation hyperparameters
BATCH_SIZE, LR, T_EPOCHS = args.batch_size, args.lr, args.t_epochs
TAU, BETA, DATA_STD = args.tau, args.beta, args.data_std

# Experiment layout
GRANULAR = args.granular
N_reps, N_p, N_t = args.n_reps, args.n_p, args.n_t

# Student / teacher configuration
EQUIV_INIT = args.equiv_init
TEACHER_MODE, FIX_TEACHER = args.teacher_mode, args.fix_teacher

# Output location prefix
path_save = args.path_save

#path_folder = 'drive/MyDrive/PaperExperiments/invariant-mean-field-neural-networks'

#import pathlib
#project_path = pathlib.Path(__file__).parent.resolve()
#project_path = os.path.abspath(path_folder)
#sys.path.insert(1, project_path)

import jax.numpy as jnp
from jax import vmap

import objax
from emlp.reps import V,sparsify_basis
from emlp.groups import SO,O,S,Z

import src.modules
from src.visualization import vis, particle_plot, plot_losses, particle_plot_animation
from src.theory_utils import equivariance_err, Wasserstein_Distance, rel_measure_distance
from src.modules import ShallowMLPNoLinearOut, FA_Model, SGD
from src.train_eval_utils import random_compare, training_loop
from src.utils import ExpData, CumData

# Symmetry group and the representations acting on inputs and outputs.
G = S(2)
repin, repout = V(G), V(G)

# Representation on the parameters (maps repin -> repout), with its
# equivariant projector, equivariant basis and the dense action of the
# first discrete generator of G.
rep_params = (repin >> repout)
P_params = rep_params.equivariant_projector()
base_params = rep_params.equivariant_basis()
G_generator = rep_params.rho_dense(G.discrete_generators[0])


def _apply_projector(x):
    # Single-particle projection with P_params.
    return P_params @ x


def _apply_basis(x):
    # Single-particle multiplication by base_params.
    return base_params @ x


def _apply_basis_transpose(x):
    # Single-particle multiplication by the transpose of base_params.
    return base_params.T @ x


def _apply_generator(x):
    # Single-particle action of the group generator.
    return jnp.dot(G_generator, x)


# Vectorized (over the leading axis) application of the maps above.
vP_params = vmap(_apply_projector)
vbase_params = vmap(_apply_basis)
vbase_paramsT = vmap(_apply_basis_transpose)

# The orbit maps are, unfortunately, restricted to this specific case of S(2):
# only the first discrete generator is applied.
vG_generator = vmap(_apply_generator)


def vorbit(x):
    # Stacks the particles on top of their images under the generator, so the
    # result has twice as many rows as `x` and contains each point's orbit.
    return jnp.vstack([x, vG_generator(x)])


scale_factor = 0.5


def create_model(N_p, activation_fn, mode="free", fixed_init = None):
    """Build a ShallowMLPNoLinearOut network and set its particles.

    Args:
        N_p: Base number of particles. The "weak" modes build the network
            with 2*N_p particles; every other mode uses N_p.
        activation_fn: Activation handed to the network constructor.
        mode: "strong"/"strong-equivariant" builds the network with
            equivariant=True. "weak"/"weak-equivariant" builds a
            non-equivariant network whose particles are passed through
            vorbit. Any other value builds a plain non-equivariant network.
        fixed_init: Optional particles used in place of the ones reported by
            the freshly built network's get_particles().

    Returns:
        The constructed model, after set_particles() has been called on it.
    """
    strong = mode in ("strong", "strong-equivariant")
    weak = (not strong) and mode in ("weak", "weak-equivariant")

    model = ShallowMLPNoLinearOut(
        N=2*N_p if weak else N_p,
        rep_in=repin,
        rep_out=repout,
        activation=activation_fn,
        alternative=True,
        use_bias=False,
        equivariant=strong,
    )

    particles = model.get_particles() if fixed_init is None else fixed_init
    if weak:
        # NOTE(review): with fixed_init=None the network was already built
        # with 2*N_p particles, and vorbit doubles the row count again
        # (presumably 4*N_p rows) -- confirm set_particles accepts this.
        particles = vorbit(particles)

    model.set_particles(particles)
    return model


def create_V_DA_FA_EA(N_p, activation_fn, equiv_init=False, seed=0):
    """Create the four student networks with a shared initial condition.

    Args:
        N_p: Number of particles of every student.
        activation_fn: Activation handed to each network constructor.
        equiv_init: When true, the shared initial particles are first passed
            through vP_params (concentrated in E^G).
        seed: Seed of the objax generator used to draw the shared particles.

    Returns:
        Tuple (model, model_DA, model_FA, model_EA): three "free" networks
        (the FA one wrapped in FA_Model) and one built with equivariant=True.
    """
    def _shallow(equivariant):
        # All four students share every constructor argument except this flag.
        return ShallowMLPNoLinearOut(N=N_p, rep_in=repin, rep_out=repout, activation=activation_fn,
                                     alternative=True, use_bias=False, equivariant=equivariant)

    # "Free" networks are NOT constrained to the space of equivariant parameters.
    model = _shallow(False)
    model_DA = _shallow(False)
    model_FA = FA_Model(_shallow(False), G, repin, repout)
    model_EA = _shallow(True)

    # Random, but common, initial condition (normal draws divided by 4).
    particle_shape = model.get_particles().shape
    shared_w = objax.random.normal(particle_shape, generator=objax.random.Generator(seed=seed))/4
    if equiv_init:
        shared_w = vP_params(shared_w)

    for free_model in (model, model_DA, model_FA):
        free_model.set_particles(shared_w)

    # The EA student always receives projected particles, passed through
    # vbase_paramsT; the projection is skipped if it was already applied.
    projected_w = shared_w if equiv_init else vP_params(shared_w)
    model_EA.set_particles(vbase_paramsT(projected_w))

    return model, model_DA, model_FA, model_EA
  
  

def equiv_error(models, labels, sample):
  """Map each label to the scalar equivariance_err of its model on `sample`.

  Models and labels are paired positionally; the error is evaluated with the
  module-level G, repin and repout and converted to a Python scalar.
  """
  return {
      label: equivariance_err(sample, model, G, repin, repout).item()
      for model, label in zip(models, labels)
  }
  
def teacher_error(model_t, models_s, labels, sample):
    """Average root-mean-square gap between each student and the teacher.

    For every student, the squared difference to the teacher output is
    averaged over axis 1, square-rooted, and then averaged over the sample.

    Args:
        model_t: Callable teacher; evaluated once on `sample`.
        models_s: A single callable student, or a list/tuple of them.
        labels: One label per student. A single string is treated as the
            label of a single student (it is not indexed character-wise).
        sample: Input batch passed to every model.

    Returns:
        Dict mapping each label to the student's mean error as a Python float.

    Raises:
        IndexError: If there are fewer labels than students.
    """
    if not isinstance(models_s, (list, tuple)):
        models_s = [models_s]
    if isinstance(labels, str):
        # Without this, labels[0] would be the first character of the string.
        labels = [labels]

    y_t = model_t(sample)
    errors = {}
    for j, model_s in enumerate(models_s):
        per_sample_rms = jnp.sqrt(((model_s(sample) - y_t)**2).mean(axis=1))
        errors[labels[j]] = per_sample_rms.mean().item()
    return errors
    
def new_rel_measure_distance(m1, m2, p=2, root = False, mode = "trace", return_Wasserstein=False):
    """Wasserstein distance between two particle sets, relative to their p-th moments.

    The value is 2 * Wasserstein_Distance(m1, m2) divided by the sum of the
    two mean p-th moments (row sums of m**p averaged over rows); with
    `root=True` that sum is raised to the power 1/p first.

    Args:
        m1, m2: Particle arrays; moments are summed over axis 1.
        p: Exponent used for both the distance and the moments.
        root: Forwarded to Wasserstein_Distance and applied to the normaliser.
        mode: Accepted for interface compatibility; not used here.
        return_Wasserstein: When true, also return the raw distance.

    Returns:
        The relative distance as a Python scalar, or the pair
        (relative distance, raw Wasserstein distance).
    """
    wasserstein = Wasserstein_Distance(m1, m2, p=p, root=root)

    moment_sum = ((m1**p).sum(axis=1)).mean() + ((m2**p).sum(axis=1)).mean()
    if root:
        normaliser = jnp.power(moment_sum, 1/p)
    else:
        normaliser = moment_sum

    relative = (2*wasserstein/normaliser).item()
    if return_Wasserstein:
        return relative, wasserstein
    return relative
    
def distances(particles1, particles2):
    """Pairwise relative and Wasserstein distances along two particle sequences.

    The sequences are walked in lockstep (stopping at the shorter one).

    Returns:
        Two lists (RM, WD): the relative measure distances and the raw
        Wasserstein distances, one entry per pair.
    """
    per_pair = [
        new_rel_measure_distance(first, second, root=False, mode="trace", return_Wasserstein=True)
        for first, second in zip(particles1, particles2)
    ]
    RM = [relative for relative, _ in per_pair]
    WD = [wasserstein for _, wasserstein in per_pair]
    return RM, WD
    
def one_trial_V_DA_FA_EA(teacher_network, N_p, activation_fn, seed, train_params, equiv_init):
    """Train the vanilla, DA, FA and EA students against one teacher.

    Args:
        teacher_network: Teacher model; also wrapped in FA_Model to obtain
            the symmetrised teacher used for the "sym" distances.
        N_p: Number of student particles; the epoch count is T_EPOCHS * N_p.
        activation_fn: Activation for the student networks.
        seed: Seed for the common student initialisation and for every
            training_loop call.
        train_params: Dict with keys BATCH_SIZE, LR, T_EPOCHS, TAU, BETA,
            DATA_STD, GRANULAR and P_MAP.
        equiv_init: Forwarded to create_V_DA_FA_EA.

    Returns:
        Tuple (eq_error, dist_teacher, dist_sym_teacher, train_losses,
        particles). The first three are dicts with "start"/"end" entries
        keyed by student label. The last two are dicts keyed by student
        label, with the particle history sub-sampled every
        max(1, N_p // GRANULAR) entries.
    """
    BATCH_SIZE, LR, T_EPOCHS = train_params["BATCH_SIZE"], train_params["LR"], train_params["T_EPOCHS"]
    TAU, BETA, DATA_STD = train_params["TAU"], train_params["BETA"], train_params["DATA_STD"]
    GRANULAR, P_MAP = train_params["GRANULAR"], train_params["P_MAP"]
    EPOCHS = T_EPOCHS*N_p
    PRINT_EVERY = EPOCHS/2 #5000
    LABELS = ["vanilla", "DA", "FA", "EA"]

    sym_teacher_network = FA_Model(teacher_network, G, repin, repout)

    model, model_DA, model_FA, model_EA = create_V_DA_FA_EA(N_p, activation_fn, equiv_init=equiv_init, seed=seed)
    students = [model, model_DA, model_FA, model_EA]
    test_sample = objax.random.normal((100,2), stddev=DATA_STD)

    eq_error, dist_teacher, dist_sym_teacher = {}, {}, {}

    def _record(stage):
        # Evaluate all three metrics for the current state of the students.
        eq_error[stage] = equiv_error(models=students, labels=LABELS, sample=test_sample)
        dist_teacher[stage] = teacher_error(teacher_network, students, labels=LABELS, sample=test_sample)
        dist_sym_teacher[stage] = teacher_error(sym_teacher_network, students, labels=LABELS, sample=test_sample)

    _record("start")

    # Arguments shared by every training run. The seed is the function
    # argument, not a module-level global.
    shared_kwargs = dict(teacher_network=teacher_network, N=N_p, input_dim=repin.size(), BS=BATCH_SIZE,
                         lr=LR, NUM_EPOCHS=EPOCHS, tau=TAU, beta=BETA, stddev=DATA_STD,
                         print_every=PRINT_EVERY, seed=seed, return_particles=True)
    # Per-student differences: only DA passes the symmetry triple, and the
    # EA student is trained without a projection map.
    per_student_kwargs = [
        dict(proj_map=P_MAP),
        dict(proj_map=P_MAP, DA=(G, repin, repout)),
        dict(proj_map=P_MAP),
        dict(proj_map=None),
    ]

    # A step of 0 is an invalid slice step; fall back to keeping every entry
    # when N_p < GRANULAR.
    stride = max(1, N_p//GRANULAR)

    train_losses, particles = {}, {}
    for label, student, extra_kwargs in zip(LABELS, students, per_student_kwargs):
        losses, history = training_loop(modelo=student, **shared_kwargs, **extra_kwargs)
        train_losses[label] = losses
        particles[label] = history[::stride]

    _record("end")

    return eq_error, dist_teacher, dist_sym_teacher, train_losses, particles
    
    
    
    
def RMD_comparisons(particles, equiv_init, teacher_mode):
    """Compare the particle histories of the four students pairwise.

    Args:
        particles: Dict with keys "vanilla", "DA", "FA" and "EA" holding the
            particle history of each student.
        equiv_init: Selects the comparison set. When true, each free student
            is compared with its projection P(.) and all students with each
            other. Otherwise each free student is compared with its orbit
            G(.), the free students with each other, and P(FA) with EA.
        teacher_mode: Accepted for interface compatibility; not used here.

    Returns:
        Two dicts (comparisons_RMD, comparisons_WD) keyed by comparison
        label, holding the per-snapshot lists returned by distances().
    """
    vanilla, da = particles["vanilla"], particles["DA"]
    fa, ea = particles["FA"], particles["EA"]
    # TODO: maybe compute the Wasserstein distance to the teacher in terms of
    # the particles themselves...

    def project(history):
        # vP_params applied to every snapshot of the history.
        return vmap(vP_params)(jnp.array(history))

    def orbit(history):
        # vorbit applied to every snapshot of the history.
        return vmap(vorbit)(jnp.array(history))

    if equiv_init:
        pairs = [
            ("V vs. P(V)", vanilla, project(vanilla)),
            ("DA vs. P(DA)", da, project(da)),
            ("FA vs. P(FA)", fa, project(fa)),
            ("V vs. DA", vanilla, da),
            ("V vs. FA", vanilla, fa),
            ("V vs. EA", vanilla, ea),
            ("DA vs. FA", da, fa),
            ("DA vs. EA", da, ea),
            ("FA vs. EA", fa, ea),
        ]
    else:
        pairs = [
            ("V vs. G(V)", vanilla, orbit(vanilla)),
            ("DA vs. G(DA)", da, orbit(da)),
            ("FA vs. G(FA)", fa, orbit(fa)),
            ("V vs. DA", vanilla, da),
            ("V vs. FA", vanilla, fa),
            ("DA vs. FA", da, fa),
            ("P(FA) vs. EA", project(fa), ea),
        ]

    comparisons_RMD, comparisons_WD = {}, {}
    for label, left, right in pairs:
        comparisons_RMD[label], comparisons_WD[label] = distances(left, right)
    return comparisons_RMD, comparisons_WD
    
# Activation shared by the teacher and all students.
# Alternatives: objax.functional.tanh, objax.functional.selu, None
activation_fn = objax.functional.sigmoid


LABELS = ["vanilla", "DA", "FA", "EA"]

# Hand-picked teacher particles are only available for N_t == 5; otherwise
# the teacher keeps the particles it was constructed with.
fixed_teacher_particles = None
if FIX_TEACHER and N_t == 5:
    if TEACHER_MODE == "strong":
        raw_teacher_particles = [[1,0],[0.5,1],[-0.5,0.3],[0,-1], [0.7, 0.7]]
    else:
        raw_teacher_particles = [[-1,0,0,0.5],[0.5,1,0,1],[-0.5,0.3,1,0],[0,-1,-0.5,1], [0.7, -0.7,0.5,0.7]]
    fixed_teacher_particles = scale_factor*jnp.array(raw_teacher_particles)

teacher_network = create_model(N_t, activation_fn, mode=TEACHER_MODE, fixed_init=fixed_teacher_particles)

# Everything a trial needs, plus the metadata stored next to the results.
train_params = {
    "BATCH_SIZE": BATCH_SIZE,
    "LR": LR,
    "T_EPOCHS": T_EPOCHS,
    "TAU": TAU,  # This is the "norm" regularization parameter
    "BETA": BETA,  # This is the "noise" parameter
    "DATA_STD": DATA_STD,
    "GRANULAR": GRANULAR,
    "EQUIV_INIT": EQUIV_INIT,
    "TEACHER_MODE": TEACHER_MODE,
    "N_reps": N_reps,
    # Projection map used during training only for equivariant initialisation.
    "P_MAP": vP_params if EQUIV_INIT else None,
    "TEACHER_FIXED": fixed_teacher_particles is not None,
}


# Containers accumulating the per-trial outputs.
out_dict = ExpData(
    ["eq_errors", "dist_to_teacher", "dist_to_sym_teacher", "train_losses", "particles"],
    N_p,
    train_params,
)
dists_dict = ExpData(["comparisons_RMDs", "comparisons_WDs"], N_p, train_params)


# Run N_reps independent trials, one seed per repetition.
# Keep the loop variable named SEED: it is a module-level global and other
# code in this file may read it.
for SEED in range(N_reps):
    trial_out = one_trial_V_DA_FA_EA(
        teacher_network, N_p, activation_fn, SEED, train_params, equiv_init=EQUIV_INIT
    )
    out_dict.append(trial_out)

    # The last element of the trial output is the particle history dict.
    trial_particles = trial_out[-1]
    trial_distances = RMD_comparisons(
        particles=trial_particles, equiv_init=EQUIV_INIT, teacher_mode=TEACHER_MODE
    )
    dists_dict.append(trial_distances)

# path_save is used as a plain string prefix for both output names.
for store, suffix in ((out_dict, "metrics"), (dists_dict, "measure_distances")):
    store.save(path_save + suffix)