import matplotlib

from library.rnn_architectures.tensor_rank import *
from library.helper import control, classes, functions
from import_data import *
from library.plotting import utils
from library.sde.systems import SDE
import reach_plot

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import numpy as np
from slicetca.run.utils import block_mask
from datetime import datetime
import os


def train(optimize, sde_preparatory, sde_execution, net, condition_wise_map, rnn_to_data,
          neural_data, full_neural_data, data,
          train_mask, test_mask, parameters, directory, epoch_id, epochs, optim,
          trials_dimension, epoch_starts, trial_ids, cmap, epoch_labels,
          ts_execution, ts_preparatory, ts_experiment, normalized_condition, condition_np, cmap_per_condition, device):
    """Run the training loop with periodic diagnostics, checkpoints and early stopping.

    Every iteration integrates ``sde_preparatory`` over ``ts_preparatory`` and
    then ``sde_execution`` over ``ts_execution`` (via ``functions.sdeint_aaeh``),
    decodes the resulting activity with ``rnn_to_data`` and compares it with
    ``neural_data``.  The optimised loss is the mean squared error on
    ``train_mask`` restricted to a random subset of trials, plus
    ``parameters['regularization']`` times a penalty on the network state and
    on the mapped controls.

    Every ``parameters['test_freq']`` iterations the losses are printed, the
    diagnostic figure is redrawn and saved as a PDF in ``directory``, and the
    state dicts of ``sde_preparatory`` and ``rnn_to_data`` are written to
    ``model.pt`` and ``map.pt``.

    Training stops early once the standard deviation of the variance-normalised
    test loss over the last ``parameters['steps_std_convergence']`` iterations
    falls below ``parameters['min_std_convergence']``.

    Args:
        optimize: If falsy, losses are evaluated but no gradient step is taken.
        sde_preparatory, sde_execution: SDE systems integrated back to back.
        net, condition_wise_map, rnn_to_data: Model parts used for the
            activation, the control map and the decoder respectively.
        neural_data: Tensor the decoded activity is fitted to.
        full_neural_data, data, epoch_id, epochs, epoch_starts, trial_ids,
        cmap, epoch_labels, ts_experiment, normalized_condition, condition_np,
        cmap_per_condition: Only forwarded to the ``reach_plot`` helpers.
        train_mask, test_mask: Boolean masks used to index the residual.
        parameters: Configuration dictionary.
        directory: Output directory for figures and checkpoints.
        optim: Optimizer stepping the trainable parameters.
        trials_dimension: Number of trials.
        ts_execution, ts_preparatory: Integration time grids.
        device: Device on which the random trial mask is created.

    Returns:
        Array ``[train-mask MSE, test-mask MSE]`` averaged over the last
        ``parameters['steps_std_convergence']`` iterations and divided by the
        variance of ``neural_data``; two NaNs if no iteration was run.
    """

    rows, columns = 5, 5
    fig, axs, ax_giga = reach_plot.get_figure(rows, columns, parameters, directory)

    convergence_window = parameters['steps_std_convergence']
    # The target data is fixed, so its variance is computed once up front.
    data_variance = np.var(neural_data.numpy(force=True))

    losses = []
    try:
        for training_iteration in range(parameters['training_iterations']):

            sde_preparatory.build_parameterization()

            # ========= Evaluation =========
            xs_prep = functions.sdeint_aaeh(sde_preparatory, sde_preparatory.get_initial_state(trials_dimension), ts_preparatory)
            xs_prep_cut = sde_preparatory.cut_states(xs_prep)
            if parameters['control_execution']:
                # Execution keeps the controller, so it starts from the full final state.
                xs_exec = functions.sdeint_aaeh(sde_execution, xs_prep[-1], ts_execution)
            else:
                # Execution runs the network alone, so only its part of the state is passed on.
                xs_exec = functions.sdeint_aaeh(sde_execution, xs_prep_cut[0][-1], ts_execution)

            xs_exec_cut = sde_execution.cut_states(xs_exec)

            # Drop the first execution sample: it duplicates the last preparatory one.
            rnn_mp = torch.cat([xs_prep_cut[0], xs_exec_cut[0][1:]])

            data_estimate = rnn_to_data(net.activation(rnn_mp if parameters['fit_preparatory'] else xs_exec_cut[0]))

            # ========= Loss =========
            mask_trial = (torch.rand(trials_dimension, device=device) < (1-parameters['fraction_masked']))
            full_mask = train_mask & mask_trial.unsqueeze(0).unsqueeze(-1)

            residual = data_estimate - neural_data
            l = torch.mean(residual[full_mask]**2)
            l_test = torch.mean(residual[test_mask]**2)
            l_total = torch.mean(residual[train_mask]**2)

            # ========= Regularization =========
            reg_fn = torch.abs if parameters['regularization_function'] == 'abs' else torch.square

            l_reg_rnn = reg_fn(rnn_mp[:, mask_trial]).mean()
            l_reg_control_prep = reg_fn(condition_wise_map.h(xs_prep_cut[1]))[:, mask_trial].mean()
            l_reg_control_exec = reg_fn(condition_wise_map.h(xs_exec_cut[1]))[:, mask_trial].mean() if parameters['control_execution'] else torch.tensor(0)

            l_reg_total = l_reg_control_prep + l_reg_control_exec + l_reg_rnn

            l = l + l_reg_total * parameters['regularization']

            losses.append([l_total.item(), l_test.item()])

            if training_iteration % parameters['test_freq'] == 0:
                with torch.no_grad():

                    print(training_iteration, 'l_total:', l_total.item(), 'l_test:', l_test.item(), 'l_reg:', l_reg_total.item(),
                          'l_reg_control_prep:', l_reg_control_prep.item(), 'l_reg_control_exec:', l_reg_control_exec.item(),
                          'l_reg_rnn:', l_reg_rnn.item())

                    for axes_row in axs:
                        for axis in axes_row:
                            if axis is not None:
                                axis.cla()
                    ax_giga.cla()

                    # ========= Arrays to plot ==========
                    ts_experiment_prep = ts_experiment[:parameters['preparatory_steps']+1]

                    rnn_activity = net.activation(rnn_mp).numpy(force=True)
                    # Separate name so the tensor that is part of the autograd graph is not shadowed.
                    rnn_mp_np = rnn_mp.numpy(force=True)

                    controls = [xs_prep_cut[1]]
                    if parameters['control_execution']:
                        controls.append(xs_exec_cut[1][1:])
                    controls = condition_wise_map(torch.cat(controls)).numpy(force=True)

                    go_cue, mov_onset = 0, 200

                    steps = 1
                    min_trial = 0
                    max_trial = 150
                    max_time = len(rnn_mp_np)

                    adaptation_color = np.array((0.88, 0.71, 0.99))
                    baseline_color = np.array((0.3, 0.5, 1.0))
                    washout_color = baseline_color

                    cmaps_trial_factor_3d = [utils.get_cmap_interpolated(c * 0.1 + 0.9, c) for c in
                                             [baseline_color, adaptation_color, washout_color]]

                    # ========= Projection ==========
                    components = net.get_components(observation=True)

                    projection = components[2].numpy(force=True)
                    projection = projection / np.linalg.norm(projection, axis=1)[:, np.newaxis]
                    projection = projection[:parameters['rank']]

                    grad = np.concatenate(
                        [np.ones(epoch_starts[0]), np.linspace(0.7, 0.7, epoch_starts[1] - epoch_starts[0]),
                         np.linspace(0, 1, trials_dimension - epoch_starts[1])])

                    # Pad with near-zero rows so there are always at least three projection axes.
                    projection = np.concatenate([projection, 10**-6*np.ones((max(0,3-parameters['rank']), parameters['rec_dim']))])
                    x_projected = rnn_mp_np @ projection.T

                    # =========== Plotting ===========

                    # Big plot
                    reach_plot.giga_projection(ax_giga, rnn_mp_np, x_projected, normalized_condition, cmap_per_condition, cmap, grad,
                        epoch_starts, max_time, min_trial, max_trial, steps, parameters)

                    # Projection over time
                    reach_plot.projection_over_time(axs[4], x_projected, condition_np, normalized_condition, grad, steps,
                                                    ts_experiment, cmap_per_condition, go_cue, mov_onset, parameters)

                    # PCA on data
                    reach_plot.data_pca(axs[2][0], full_neural_data, normalized_condition, cmap_per_condition, grad,
                                        min_trial, max_trial, steps)

                    # Sorted by peak
                    reach_plot.sorted_by_peak(axs[0][0], rnn_activity, ts_experiment, go_cue, mov_onset)

                    # Single neuron
                    reach_plot.single_neurons(axs[1][0], ts_experiment, rnn_activity, normalized_condition, cmap_per_condition, go_cue, mov_onset)

                    # Controls
                    reach_plot.controls_over_time(axs[0][1], ts_experiment, controls, normalized_condition, ts_experiment_prep, cmap, go_cue, mov_onset, parameters)

                    # Trial factors
                    components_latent = net.get_components(observation=False)
                    reach_plot.trial_factors(axs[3], components, components_latent, normalized_condition,
                                             trial_ids.numpy(force=True), cmap, epoch_starts, adaptation_color, epoch_labels,
                                             parameters, data, trial_ids)

                    # Trial factors in 3D
                    trial_components_latent = components_latent[0].detach().cpu().numpy()
                    reach_plot.trial_factors_3d(axs[2][1], trial_components_latent, cmaps_trial_factor_3d,
                                                data, epoch_labels, parameters)

                    # Loss
                    losses_array = np.array(losses) / data_variance
                    reach_plot.loss(axs[2][3], neural_data, losses_array, parameters)

                    # Eigenspectrum
                    W_obs = net.construct_weight(observation=False).numpy(force=True)
                    reach_plot.eigenspectrum(axs[2][2], W_obs, epoch_id, cmaps_trial_factor_3d)

                    # Initial condition
                    reach_plot.initial_condition(axs[1][1], xs_exec_cut, epoch_id.numpy(force=True), cmap, normalized_condition,
                                                 condition_np, epoch_labels)

                    # Vector fields
                    axs_temp = [axs[1,4], axs[2,4]]
                    reach_plot.vector_fields(axs_temp, net, projection, adaptation_color, components, epochs, device, parameters)

                    # Text
                    ax = axs[0,4]
                    reach_plot.text(ax, parameters)

                    plt.draw()
                    plt.pause(1) #Increase for some CPU configs

                    plt.savefig(directory + '/' + directory.split('/')[-1] + '.pdf')

                    if training_iteration == 0:
                        plt.savefig(directory + '/' + directory.split('/')[-1] + '-0.pdf')

                    torch.save(sde_preparatory.state_dict(), directory + '/model.pt')
                    torch.save(rnn_to_data.state_dict(), directory + '/map.pt')

                    # Report the quantity the stopping criterion is based on.
                    if len(losses) >= convergence_window:
                        print(losses_array[-convergence_window:, 1].std())

            # Stopping criterion, evaluated on the most recent losses (column 1 = test loss).
            if len(losses) >= convergence_window:
                recent_test = np.array(losses[-convergence_window:])[:, 1] / data_variance
                if recent_test.std() < parameters['min_std_convergence']:
                    break

            if optimize:
                optim.zero_grad()
                l.backward()
                sde_preparatory.backward_parameterization()
                optim.step()
    finally:
        # Always release the figure, including when an iteration raises.
        plt.close(fig)

    if not losses:
        return np.full(2, np.nan)

    return (np.array(losses[-convergence_window:]) / data_variance).mean(axis=0)


def main(load_directory, directory, parameters, device):
    """Load the data, build the model and run :func:`train`.

    Seeds torch and NumPy from ``parameters['seed']``, imports the recordings,
    builds the train/test block masks, the network, controller, control map,
    SDE systems, decoder and Adam optimizer, optionally restores checkpoints
    from ``load_directory`` and then calls ``train``.

    Note that ``train`` is called with the module-level ``optimize`` flag,
    which is set in the ``__main__`` block of this file.

    Args:
        load_directory: Directory holding ``model.pt`` / ``map.pt`` to restore,
            or ``'.'`` to start from scratch.
        directory: Output directory handed to ``train``.
        parameters: Configuration dictionary.  ``parameters['control_dim']`` is
            overwritten with ``parameters['rank']`` when
            ``parameters['in_space_control']`` is set.
        device: Torch device for tensors and modules.

    Returns:
        Whatever ``train`` returns.

    Raises:
        ValueError: If ``parameters['preparatory_steps']`` exceeds the number of
            samples available before time zero.
    """

    torch.manual_seed(parameters['seed'])
    np.random.seed(parameters['seed'])

    neural_data, data, angles, condition, times, epochs, trial_ids = import_data(parameters['data_directories'])

    # start: first sample with time >= 0; stop: sample with the largest time.
    start, stop = np.argmax(times >= 0), np.argmax(times)
    prep = start - parameters['preparatory_steps']
    if prep < 0:
        # A negative index would silently slice from the end of the arrays.
        raise ValueError(
            f"preparatory_steps={parameters['preparatory_steps']} exceeds the "
            f"{start} samples available before time zero")
    ts_experiment = times[prep:stop]*100

    full_neural_data = torch.tensor(neural_data, device=device)
    neural_data = neural_data[start:stop] if not parameters['fit_preparatory'] else neural_data[prep:stop]

    time_dimension, trials_dimension, neurons_dimension = neural_data.shape
    condition_dimension = len(np.unique(condition))

    epoch_labels = ['BL', 'AD', 'WO']
    # Index of the first trial of each epoch after the first one.
    epoch_starts = [np.argmax(data['epoch']==e) for e in epoch_labels][1:]

    print('Dimensions - Time:', time_dimension, ' Trial:', trials_dimension, 'Neuron:', neurons_dimension, 'Condition:', condition_dimension)

    normalized_condition = condition / (np.max(condition)+1)
    condition_np = normalized_condition
    neural_data = torch.tensor(neural_data, device=device)
    condition = torch.tensor(condition, device=device, dtype=torch.long)
    trial_ids = torch.tensor(trial_ids, device=device)

    half_train_block, half_test_block = [10,10,1], [5,5,1]
    nd_shape = np.array(neural_data.shape)
    number_blocks = int(parameters['fraction_test']*np.prod(nd_shape)/np.prod(1+2*np.array(half_train_block)))
    train_mask, test_mask = block_mask(list(nd_shape), half_train_block, half_test_block, number_blocks, device=device)

    print('Mask train proportion:', train_mask.float().mean().item())

    cmap = matplotlib.colormaps['hsv']
    cmap_per_condition = [utils.get_cmap_black(cmap(i)) for i in np.unique(condition_np)]

    activation = torch.tanh

    kernel = RationalQuadratic(parameters['l'], parameters['sigma'])

    # Integer epoch index per trial, following the order of epoch_labels.
    epoch_id = torch.tensor(sum([ei * (data['epoch']==e) for ei, e in enumerate(epoch_labels)]), device=device, dtype=torch.long)

    if parameters['in_space_control']:
        parameters['control_dim'] = parameters['rank']

    net = TensorRankSmooth(parameters['rec_dim'], parameters['rank'],
                           trial_ids, None, kernel, epoch=epoch_id if parameters['discontinuous_covariance'] else None,
                           activation=activation, std_observation=parameters['sigma_observation'],
                           optimize_input_maps=True, in_space_inputs=parameters['in_space_control'],
                           condition=condition if parameters['condition_specific_init'] else None,
                           bias=parameters['bias'],
                           in_dims=(parameters['control_dim'],),
                           time_constant=parameters['time_constant'],
                           noise=parameters['noise'], device=device)

    activation_control = torch.relu
    controller = control.TrialRankController(parameters['control_hidden_dim'], trials_dimension, condition_dimension,
                                          parameters['control_dnn_dim'],
                                          time_constant=parameters['time_constant_control'],
                                          in_dims=[],
                                          init='optimized', device=device, activation=activation_control)

    condition_wise_map = classes.ConditionWiseControlMap(condition, parameters['control_hidden_dim'],
                                                         parameters['control_dim'], activation=activation,
                                                         device=device)

    graph = np.array([[1, condition_wise_map],
                      [0, 1]])
    sde_preparatory = SDE([net, controller], graph)

    if parameters['control_execution']:
        graph = np.array([[1, condition_wise_map],
                          [0, 1]])
        sde_execution = SDE([net, controller], graph)
    else:
        sde_execution = SDE([net], [[1]])

    if parameters['orthogonal_decoder']:
        rnn_to_data = classes.OrthogonalScaled(parameters['rec_dim'], neurons_dimension,
                                               bias=parameters['decoder_bias'],
                                               scaling=False,
                                               scaling_neuron_wise=False, device=device)
    else:
        rnn_to_data = nn.Linear(parameters['rec_dim'], neurons_dimension,
                                bias=parameters['decoder_bias'], device=device)

    # De-duplicate while keeping a deterministic order: torch optimizers expect
    # an ordered collection, and a set's ordering changes between runs.
    param = list(dict.fromkeys(list(sde_preparatory.parameters()) + list(rnn_to_data.parameters())))

    optim = torch.optim.Adam(param, lr=parameters['learning_rate'])

    ts = torch.linspace(0, parameters['duration'], stop-start)
    # Same step size as the execution grid, covering the preparatory samples.
    ts_preparatory = torch.linspace(0, (parameters['preparatory_steps']+1)*(ts[1]-ts[0]), start-prep+1)
    ts_execution = ts

    if load_directory != '.':
        sde_preparatory.load_state_dict(torch.load(load_directory + '/model.pt', map_location=device), strict=False)
        rnn_to_data.load_state_dict(torch.load(load_directory + '/map.pt', map_location=device))

    return train(optimize, sde_preparatory, sde_execution, net, condition_wise_map, rnn_to_data,
          neural_data, full_neural_data, data,
          train_mask, test_mask, parameters, directory, epoch_id, epochs, optim,
          trials_dimension, epoch_starts, trial_ids, cmap, epoch_labels,
          ts_execution, ts_preparatory, ts_experiment, normalized_condition, condition_np, cmap_per_condition, device)

def cross_validation():
    """Grid-search two hyper-parameters over several seeds and plot the test loss.

    For every combination in ``cv_parameters`` and every seed, ``main`` is run
    in a fresh timestamped sub-directory and element ``[1]`` of its return
    value (the test loss) is stored.  A run that raises is recorded as NaN so
    the sweep continues.  After each run the grid is saved to ``cv_grid.npy``
    and the summary figure is redrawn and saved.

    Relies on the module-level ``device`` set in the ``__main__`` block.
    """
    # Local imports keep this block self-contained.
    import traceback
    import warnings

    cv_parameters = {'rec_dim' : [200], 'rank' : [1, 2, 3, 4, 5, 10, 20]}

    seeds = [10, 11, 12]

    load_directory = '.'

    directory = functions.make_directory('cv')
    parameters = functions.load_yaml(load_directory, directory)

    parameters['data_directories'] = parameters['data_directories'][parameters['session']]

    keys = list(cv_parameters.keys())

    # NaN marks runs that are still pending or that failed.
    cv_losses = np.full((len(cv_parameters[keys[0]]), len(cv_parameters[keys[1]]), len(seeds)), np.nan)
    fig = plt.figure(figsize=(5, 4), constrained_layout=True)
    ax = fig.add_subplot()
    cmap = matplotlib.colormaps['Set2']

    plt.show(block=False)

    for p1i, p1 in enumerate(cv_parameters[keys[0]]):
        for p2i, p2 in enumerate(cv_parameters[keys[1]]):
            for si, s in enumerate(seeds):
                print(keys[0], p1, keys[1], p2, 'seed', s)
                parameters[keys[0]] = p1
                parameters[keys[1]] = p2
                parameters['seed'] = s

                date_time = datetime.now().strftime("%d-%m-%Y_%H_%M_%S")
                directory_run = directory + '/' + date_time
                os.makedirs(directory_run, exist_ok=True)
                try:
                    l = main(load_directory, directory_run, parameters, device)[1]
                except Exception:
                    # Best effort: report the failure in full and carry on with the grid.
                    print('\n\n')
                    traceback.print_exc()
                    print('\n\n')

                    l = np.nan
                cv_losses[p1i, p2i, si] = l

                np.save(directory+'/cv_grid.npy', cv_losses)

                # NaN-aware statistics so pending or failed seeds do not blank a
                # grid point; all-NaN slices still yield NaN, and the RuntimeWarning
                # NumPy emits for them is expected here.
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', category=RuntimeWarning)
                    cv_losses_mean = np.nanmean(cv_losses, axis=-1)
                    cv_loss_std = np.nanstd(cv_losses, axis=-1)

                # NOTE(review): this runs before ax.cla(); confirm the styling it
                # applies is meant to survive the clear.
                utils.set_bottom_axis(ax)

                ax.cla()
                for pi, p in enumerate(cv_parameters[keys[0]]):
                    ax.errorbar(cv_parameters[keys[1]], cv_losses_mean[pi], cv_loss_std[pi], fmt='-o',
                                label=keys[0]+'='+str(p),color=cmap(pi), linewidth=1.5)

                ax.set_xticks(cv_parameters[keys[1]])
                ax.set_xlabel(keys[1])
                ax.set_ylabel('MSE')
                ax.set_title('CV loss')
                ax.legend()

                # Address the summary figure explicitly: the pyplot "current figure"
                # may be a training figure left open by a failed run.
                fig.savefig(directory+'/cv_' + directory.split('/')[-1] +'.pdf')
                fig.canvas.draw_idle()
                plt.pause(0.1)

def single_fit():
    """Run a single fit with the configuration returned by ``functions.load_yaml``.

    Relies on the module-level ``device`` set in the ``__main__`` block.
    """
    source_dir = '.'

    run_dir = functions.make_directory()
    config = functions.load_yaml(source_dir, run_dir)

    # Keep only the data directories belonging to the selected session.
    session = config['session']
    config['data_directories'] = config['data_directories'][session]

    main(source_dir, run_dir, config, device)


if __name__ == '__main__':

    # Module-level settings: `optimize` is read by main(), `device` by
    # single_fit() and cross_validation().
    optimize = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    utils.set_font()

    # Entry point; call cross_validation() here instead to run the grid search.
    single_fit()
