# -*- coding: utf-8 -*-
"""
Created on Mon Apr 29 21:44:57 2024

@author: javie
"""
# training and evaluation
from tqdm import tqdm
import jax.numpy as jnp
from jax import vmap
import objax

from src.modules import SGD


def training_loop(modelo, teacher_network, N, input_dim, BS, lr, NUM_EPOCHS, tau=0, beta=0, print_every=1000, stddev=1, proj_map = None, DA=None, seed = 0, return_particles=False):
    """
    Generic training loop function.

    Trains ``modelo`` with SGD followed by an additive noise perturbation, on
    Gaussian inputs labelled by ``teacher_network``. A new minibatch ``X[i]``
    is used at every epoch, so ``NUM_EPOCHS`` is also the number of optimizer
    steps.

    Parameters
    ----------
    modelo : objax.nn.Module
        model to be trained. Must provide ``get_particles()`` and
        ``get_param_shape()``.
    teacher_network : objax.nn.Module
        teacher_network to be used for generating the training data.
    N : int
        Number of student particles; only used to scale the training noise,
        whose standard deviation is ``sqrt(2*beta*lr/N)``.
    input_dim : int
        input space dimension.
    BS : int
        Batch Size for NN training.
    lr : float
        initial learning rate for the optimizer.
    NUM_EPOCHS : int
        number of training epochs (one minibatch / one step each).
    tau : float, optional
        quadratic/ridge regularization parameter: weight of the mean, over the
        rows of ``modelo.get_particles()``, of their squared L2 norm.
        The default is 0.
    beta : float, optional
        training noise parameter. The default is 0 (no noise).
    print_every : int, optional
        the learning rate is halved after every epoch ``i`` with
        ``i % print_every == 0``. This includes the very first epoch
        (``i = 0``). The default is 1000.
    stddev : float, optional
        standard deviation for the gaussian input. The default is 1.
    proj_map : Callable, optional
        projection map applied (vmapped over epochs) to noise drawn with the
        shape of ``modelo.get_particles()``; each projected sample is then
        flattened. If None, unprojected noise of shape
        ``modelo.get_param_shape()`` is used. The default is None.
    DA : tuple (group, repin, repout), optional
        If not None, every minibatch is augmented by appending the inputs and
        targets transformed by ``repin.rho_dense`` / ``repout.rho_dense`` of
        the group's first discrete generator. The default is None.
    seed : int, optional
        Seed of the random generator used for both the training inputs and
        the training noise. The default is 0.
    return_particles : bool, optional
        If True, also return the particles recorded before training and after
        every epoch. The default is False.

    Returns
    -------
    list or tuple
        The values returned by each training step (the loss, evaluated before
        that step's update). When ``return_particles`` is True, the tuple
        ``(train_losses, particles)`` is returned instead.
    """
    opt = SGD(modelo.vars())

    @objax.Jit
    @objax.Function.with_vars(modelo.vars())
    def get_particles_model():
        return modelo.get_particles()

    @objax.Jit
    @objax.Function.with_vars(modelo.vars())
    def loss(x, y, tau=0):
        # Mean squared error plus tau * mean squared row-norm of the particles.
        yhat = modelo(x)
        particles = get_particles_model()
        norm_penalization = ((particles**2).sum(axis=1)).mean()
        return ((yhat - y)**2).mean() + tau * norm_penalization

    gradient_loss = objax.GradValues(loss, modelo.vars())

    @objax.Jit
    @objax.Function.with_vars(modelo.vars() + opt.vars())
    def train(x, y, lr, tau=0, noise=0, std_noise=0):
        g, v = gradient_loss(x, y, tau=tau)  # gradients and loss value
        opt(lr=lr, gradients=g)              # SGD step
        # Every variable tracked by the optimizer receives the same pre-drawn
        # noise sample (assumes its shape broadcasts against each ref).
        for ref in opt.refs:
            ref.value += std_noise * noise
        return v

    train_losses = []
    if return_particles:
        particles = [get_particles_model()]

    # A single generator for all randomness. Creating Generator(seed) twice
    # would reuse the same PRNG key for the inputs and the noise, making them
    # statistically dependent. X is the first draw, so it matches what the
    # previous implementation produced for the same seed.
    generator = objax.random.Generator(seed=seed)
    X = objax.random.normal((NUM_EPOCHS, BS, input_dim), stddev=stddev, generator=generator)

    apply_DA = DA is not None
    if apply_DA:
        G, repin, repout = DA

        @objax.Jit
        @objax.Function.with_vars(modelo.vars())
        def augment_in(x):
            # Append the inputs transformed by the first discrete generator.
            return jnp.vstack([x, vmap(lambda v: repin.rho_dense(G.discrete_generators[0]) @ v)(x)])

        @objax.Jit
        @objax.Function.with_vars(modelo.vars())
        def augment_out(y):
            # Append the targets transformed by the same group element.
            return jnp.vstack([y, vmap(lambda v: repout.rho_dense(G.discrete_generators[0]) @ v)(y)])

    if proj_map is None:
        NOISE = objax.random.normal((NUM_EPOCHS, *modelo.get_param_shape()), generator=generator)
    else:
        raw_noise = objax.random.normal((NUM_EPOCHS, *modelo.get_particles().shape), generator=generator)
        NOISE = vmap(proj_map)(raw_noise).reshape(NUM_EPOCHS, -1)

    for i in tqdm(range(NUM_EPOCHS)):
        x = X[i]
        y = teacher_network(x)

        if apply_DA:
            x = augment_in(x)
            y = augment_out(y)

        std_noise = jnp.sqrt(2 * beta * lr / N)
        train_losses.append(train(x, y, lr=lr, tau=tau, noise=NOISE[i], std_noise=std_noise))
        if return_particles:
            particles.append(get_particles_model())
        # NOTE(review): i == 0 satisfies this test, so lr is already halved
        # after the first step; kept to preserve existing training schedules.
        if i % print_every == 0:
            lr = lr / 2

    return (train_losses, particles) if return_particles else train_losses


def training_loop_oneL(modelo, teacher_network, N, input_dim, BS, lr, NUM_EPOCHS, tau=0, beta=0, print_every=1000, stddev=1, proj_map = None, DA=None, seed = 0, return_particles=False):
    """
    Training loop adapted for truly single-layer NNs.

    Trains ``modelo`` with SGD followed by additive Gaussian noise, on
    Gaussian inputs labelled by ``teacher_network`` (one minibatch / one step
    per epoch). Unlike a pre-drawn noise array, the noise here is sampled
    inside the jitted step from a generator seeded with ``seed``: one sample
    per optimizer variable, per step.

    Parameters
    ----------
    modelo : objax.nn.Module
        model to be trained. Must provide ``get_particles()`` and
        ``get_param_shape()``.
    teacher_network : objax.nn.Module
        network generating the training targets.
    N : int
        Number of student particles; only used to scale the noise, whose
        standard deviation is ``sqrt(2*beta*lr/N)``.
    input_dim : int
        input space dimension.
    BS : int
        Batch Size for NN training.
    lr : float
        initial learning rate.
    NUM_EPOCHS : int
        number of training epochs (one minibatch each).
    tau : float, optional
        ridge weight on the mean squared row-norm of ``modelo.get_particles()``.
        The default is 0.
    beta : float, optional
        training noise parameter. The default is 0 (no noise).
    print_every : int, optional
        the learning rate is halved after every epoch ``i`` with
        ``i % print_every == 0`` (including ``i = 0``). The default is 1000.
    stddev : float, optional
        standard deviation for the gaussian input. The default is 1.
    proj_map : Callable or list of Callables, optional
        If given, the noise for the j-th optimizer variable is reshaped to the
        particles' shape, mapped through ``proj_map[j]``, and reshaped back to
        ``modelo.get_param_shape()``. Only one shape pair is recorded, so this
        assumes the optimizer tracks a single variable. The default is None.
    DA : tuple (group, repin, repout), optional
        If not None, every minibatch is augmented with its image under the
        group's first discrete generator. The default is None.
    seed : int, optional
        Seed of the generator used for both the training inputs and the noise.
        The default is 0.
    return_particles : bool, optional
        If True, also return the particles recorded before training and after
        every epoch. The default is False.

    Returns
    -------
    list or tuple
        The values returned by each training step (the loss, evaluated before
        that step's update), or ``(train_losses, particles)`` when
        ``return_particles`` is True.
    """
    opt = SGD(modelo.vars())

    @objax.Jit
    @objax.Function.with_vars(modelo.vars())
    def get_particles_model():
        return modelo.get_particles()

    @objax.Jit
    @objax.Function.with_vars(modelo.vars())
    def loss(x, y, tau=0):
        # Mean squared error plus tau * mean squared row-norm of the particles.
        yhat = modelo(x)
        particles = get_particles_model()
        norm_penalization = ((particles**2).sum(axis=1)).mean()
        return ((yhat - y)**2).mean() + tau * norm_penalization

    gradient_loss = objax.GradValues(loss, modelo.vars())
    random_generator = objax.random.Generator(seed=seed)

    use_proj = proj_map is not None
    if use_proj:
        if not isinstance(proj_map, list):
            proj_map = [proj_map]
        p_shapes = [get_particles_model().shape]
        shapes = [modelo.get_param_shape()]

    @objax.Jit
    @objax.Function.with_vars(modelo.vars() + opt.vars() + random_generator.vars())
    def train(x, y, lr, tau=0, std_noise=0):
        g, v = gradient_loss(x, y, tau=tau)  # gradients and loss value
        opt(lr=lr, gradients=g)              # SGD step
        for j, ref in enumerate(opt.refs):
            noise = objax.random.normal(ref.value.shape, generator=random_generator)
            if use_proj:
                noise = proj_map[j](noise.reshape(p_shapes[j])).reshape(shapes[j])
            ref.value += std_noise * noise
        return v

    train_losses = []
    # Draw the inputs from the seeded generator. The global default generator
    # was used before, so ``seed`` did not make the training data reproducible.
    X = objax.random.normal((NUM_EPOCHS, BS, input_dim), stddev=stddev, generator=random_generator)
    if return_particles:
        particles = [get_particles_model()]

    apply_DA = DA is not None
    if apply_DA:
        G, repin, repout = DA

        @objax.Jit
        @objax.Function.with_vars(modelo.vars())
        def augment_in(x):
            # Append the inputs transformed by the first discrete generator.
            return jnp.vstack([x, vmap(lambda v: repin.rho_dense(G.discrete_generators[0]) @ v)(x)])

        @objax.Jit
        @objax.Function.with_vars(modelo.vars())
        def augment_out(y):
            # Append the targets transformed by the same group element.
            return jnp.vstack([y, vmap(lambda v: repout.rho_dense(G.discrete_generators[0]) @ v)(y)])

    for i in tqdm(range(NUM_EPOCHS)):
        x = X[i]
        y = teacher_network(x)
        if apply_DA:
            x = augment_in(x)
            y = augment_out(y)
        std_noise = jnp.sqrt(2 * beta * lr / N)
        train_losses.append(train(x, y, lr=lr, tau=tau, std_noise=std_noise))
        if return_particles:
            particles.append(get_particles_model())
        # NOTE(review): also true at i == 0, so lr is halved after the first step.
        if i % print_every == 0:
            lr = lr / 2
    return (train_losses, particles) if return_particles else train_losses



def training_loop_aug(modelo, teacher_network, N, input_dim, BS, lr, NUM_EPOCHS, tau=0, beta=0, print_every=1000, stddev=1, proj_map = None, DA = None, FA=False, seed = 0):
    """
    Generic training loop function (variant with data augmentation / FA models).

    Trains ``modelo`` with SGD followed by additive noise, on Gaussian inputs
    labelled by ``teacher_network`` (one minibatch / one step per epoch). The
    particles are the rows of the ``linear.w`` weight reshaped to
    ``modelo.N`` rows.

    Parameters
    ----------
    modelo : objax.nn.Module
        model to be trained.
    teacher_network : objax.nn.Module
        teacher_network to be used for generating the training data.
    N : int
        Number of student particles. Used to scale the noise (standard
        deviation ``sqrt(2*beta*lr/N)``) and, with ``proj_map``, to reshape
        the noise to the particles' shape.
    input_dim : int
        input space dimension.
    BS : int
        Batch Size for NN training.
    lr : float
        initial learning rate for the optimizer.
    NUM_EPOCHS : int
        number of training epochs (one minibatch each).
    tau : float, optional
        quadratic/ridge regularization parameter. The default is 0.
    beta : float, optional
        training noise parameter. The default is 0 (no noise).
    print_every : int, optional
        the learning rate is halved after every epoch ``i`` with
        ``i % print_every == 0`` (including ``i = 0``). The default is 1000.
    stddev : float, optional
        standard deviation for the gaussian input. The default is 1.
    proj_map : Callable, optional
        projection map applied (vmapped over epochs) to noise of shape
        ``(N, -1)``; each projected sample is flattened. The default is None.
    DA : tuple containing (emlp.group, emlp.rep, emlp.rep), optional
        If given, tuple containing group, repin and repout. Every minibatch is
        augmented with its image under the group's first discrete generator.
        None (the default) or False disables augmentation.
    FA : bool, optional
        If True, the model is assumed to be an "FA" version wrapping the
        network in ``modelo.model``. The default is False.
    seed : int, optional
        Seed of the random generator used for both the training inputs and
        the training noise. The default is 0.

    Returns
    -------
    list
        The values returned by each training step (the loss, evaluated before
        that step's update).
    """
    opt = SGD(modelo.vars())

    # Weight variable holding the particles; FA models wrap the network.
    w_var = modelo.model.linear.w if FA else modelo.linear.w

    @objax.Jit
    @objax.Function.with_vars(modelo.vars())
    def loss(x, y, tau=0):
        # Mean squared error plus tau * mean squared row-norm of the particles.
        yhat = modelo(x)
        particles = w_var.value.reshape(modelo.N, -1)
        norm_penalization = ((particles**2).sum(axis=1)).mean()
        return ((yhat - y)**2).mean() + tau * norm_penalization

    gradient_loss = objax.GradValues(loss, modelo.vars())

    @objax.Jit
    @objax.Function.with_vars(modelo.vars() + opt.vars())
    def train(x, y, lr, tau=0, noise=0, std_noise=0):
        g, v = gradient_loss(x, y, tau=tau)  # gradients and loss value
        opt(lr=lr, gradients=g)              # SGD step
        # Same pre-drawn noise sample added to every optimizer variable.
        for ref in opt.refs:
            ref.value += std_noise * noise
        return v

    train_losses = []

    # One generator for inputs and noise. Two Generator(seed) instances would
    # reuse the same PRNG key and correlate the data with the noise.
    generator = objax.random.Generator(seed=seed)
    X = objax.random.normal((NUM_EPOCHS, BS, input_dim), stddev=stddev, generator=generator)

    # The old default was DA=False, which the previous `is None` check treated
    # as "enabled" and then failed to unpack; both None and False mean "off".
    apply_DA = DA is not None and DA is not False
    if apply_DA:
        G, repin, repout = DA
        # Loop-invariant representation matrices of the first generator.
        rho_in = repin.rho_dense(G.discrete_generators[0])
        rho_out = repout.rho_dense(G.discrete_generators[0])

    if proj_map is None:
        NOISE = objax.random.normal((NUM_EPOCHS, *w_var.value.shape), generator=generator)
    else:
        p_shape = w_var.value.reshape(N, -1).shape
        raw_noise = objax.random.normal((NUM_EPOCHS, *p_shape), generator=generator)
        NOISE = vmap(proj_map)(raw_noise).reshape(NUM_EPOCHS, -1)

    for i in tqdm(range(NUM_EPOCHS)):
        x = X[i]
        y = teacher_network(x)

        if apply_DA:
            x = jnp.vstack([x, vmap(lambda v: rho_in @ v)(x)])
            y = jnp.vstack([y, vmap(lambda v: rho_out @ v)(y)])
        std_noise = jnp.sqrt(2 * beta * lr / N)
        train_losses.append(train(x, y, lr=lr, tau=tau, noise=NOISE[i], std_noise=std_noise))
        # NOTE(review): also true at i == 0, so lr is halved after the first step.
        if i % print_every == 0:
            lr = lr / 2

    return train_losses


def training_loop_L(modelo, teacher_network, N, input_dim, BS, lr, NUM_EPOCHS, tau=0, beta=0, print_every=1000, stddev=1, proj_map = None, DA=None, seed = 0):
    """
    Training loop for models whose ``get_particles()`` returns two arrays.

    The two particle arrays are concatenated along axis 1 for the ridge
    penalty. Training is SGD followed by additive Gaussian noise sampled
    inside the jitted step from a generator seeded with ``seed``, on Gaussian
    inputs labelled by ``teacher_network`` (one minibatch / one step per
    epoch).

    Parameters
    ----------
    modelo : objax.nn.Module
        model to be trained. Must provide ``get_particles()`` (returning two
        arrays) and ``get_param_shape()``.
    teacher_network : objax.nn.Module
        network generating the training targets.
    N : int
        Number of student particles; only used to scale the noise, whose
        standard deviation is ``sqrt(2*beta*lr/N)``.
    input_dim : int
        input space dimension.
    BS : int
        Batch Size for NN training.
    lr : float
        initial learning rate.
    NUM_EPOCHS : int
        number of training epochs (one minibatch each).
    tau : float, optional
        ridge weight on the mean squared row-norm of the concatenated
        particles. The default is 0.
    beta : float, optional
        training noise parameter. The default is 0 (no noise).
    print_every : int, optional
        the learning rate is halved after every epoch ``i`` with
        ``i % print_every == 0`` (including ``i = 0``). The default is 1000.
    stddev : float, optional
        standard deviation for the gaussian input. The default is 1.
    proj_map : Callable or list of Callables, optional
        If given, the noise for the j-th optimizer variable is reshaped to the
        j-th particle array's shape, mapped through ``proj_map[j]``, and
        reshaped to ``modelo.get_param_shape()[j]``. This assumes the order of
        ``opt.refs`` matches ``get_particles()`` (TODO confirm). A single
        callable is wrapped in a one-element list. The default is None.
    DA : tuple (group, repin, repout), optional
        If not None, every minibatch is augmented with its image under the
        group's first discrete generator. The default is None.
    seed : int, optional
        Seed of the generator used for both the training inputs and the noise.
        The default is 0.

    Returns
    -------
    list
        The values returned by each training step (the loss, evaluated before
        that step's update).
    """
    opt = SGD(modelo.vars())

    @objax.Jit
    @objax.Function.with_vars(modelo.vars())
    def loss(x, y, tau=0):
        # Mean squared error plus tau * mean squared row-norm of the particles.
        yhat = modelo(x)
        particles1, particles2 = modelo.get_particles()
        particles = jnp.hstack([particles1, particles2])
        norm_penalization = ((particles**2).sum(axis=1)).mean()
        return ((yhat - y)**2).mean() + tau * norm_penalization

    gradient_loss = objax.GradValues(loss, modelo.vars())
    random_generator = objax.random.Generator(seed=seed)

    use_proj = proj_map is not None
    if use_proj:
        if not isinstance(proj_map, list):
            proj_map = [proj_map]
        p_shapes = [p.shape for p in modelo.get_particles()]
        shapes = modelo.get_param_shape()

    @objax.Jit
    @objax.Function.with_vars(modelo.vars() + opt.vars() + random_generator.vars())
    def train(x, y, lr, tau=0, std_noise=0):
        g, v = gradient_loss(x, y, tau=tau)  # gradients and loss value
        opt(lr=lr, gradients=g)              # SGD step
        for j, ref in enumerate(opt.refs):
            noise = objax.random.normal(ref.value.shape, generator=random_generator)
            if use_proj:
                noise = proj_map[j](noise.reshape(p_shapes[j])).reshape(shapes[j])
            ref.value += std_noise * noise
        return v

    train_losses = []
    # Draw the inputs from the seeded generator. The global default generator
    # was used before, so ``seed`` did not make the training data reproducible.
    X = objax.random.normal((NUM_EPOCHS, BS, input_dim), stddev=stddev, generator=random_generator)

    apply_DA = DA is not None
    if apply_DA:
        G, repin, repout = DA

        @objax.Jit
        @objax.Function.with_vars(modelo.vars())
        def augment_in(x):
            # Append the inputs transformed by the first discrete generator.
            return jnp.vstack([x, vmap(lambda v: repin.rho_dense(G.discrete_generators[0]) @ v)(x)])

        @objax.Jit
        @objax.Function.with_vars(modelo.vars())
        def augment_out(y):
            # Append the targets transformed by the same group element.
            return jnp.vstack([y, vmap(lambda v: repout.rho_dense(G.discrete_generators[0]) @ v)(y)])

    for i in tqdm(range(NUM_EPOCHS)):
        x = X[i]
        y = teacher_network(x)
        if apply_DA:
            x = augment_in(x)
            y = augment_out(y)
        std_noise = jnp.sqrt(2 * beta * lr / N)
        train_losses.append(train(x, y, lr=lr, tau=tau, std_noise=std_noise))
        # NOTE(review): also true at i == 0, so lr is halved after the first step.
        if i % print_every == 0:
            lr = lr / 2

    return train_losses


######### EVALUATION

def random_compare(model_t, models_s, shape=(4,2), r_method=objax.random.uniform):
    """
    Print how closely one or more student networks track a teacher network.

    Parameters
    ----------
    model_t : objax.nn.Module
        teacher network used as the reference.
    models_s : objax.nn.Module or list of objax.nn.Module
        student network(s). A single module is treated as a one-element list.
    shape : tuple, optional
        shape of the random input batch passed to every network. The
        per-sample details are printed only when ``shape[0] <= 10``.
        The default is (4,2).
    r_method : Callable, optional
        called as ``r_method(shape)`` to produce the random input.
        The default is objax.random.uniform.

    Returns
    -------
    None. Prints the input, the teacher and student outputs, the per-sample
    root-mean-square difference over axis 1 (small batches only), and, for
    every student, the mean of that difference.
    """
    inputs = r_method(shape)
    students = models_s if isinstance(models_s, list) else [models_s]

    teacher_out = model_t(inputs)
    student_outs = [student(inputs) for student in students]

    per_sample = []
    for out in student_outs:
        squared_gap = (out - teacher_out) ** 2
        per_sample.append(jnp.sqrt(squared_gap.mean(axis=1)))

    verbose = shape[0] <= 10
    if verbose:
        print("Input: ", inputs)
        print("Teacher Output: ", teacher_out)
        print("Model Output: ", student_outs)
        print("Difference (norm): ", per_sample)

    aggregated = [gap.mean().item() for gap in per_sample]
    print("Difference (aggregated): ", aggregated)