"""
Run whole simulation for network model.

    - initial training on random reach data (following what's defined in the protocol)
    - biologically plausible adaptation run
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict, defaultdict
import itertools

import sys
import copy

from modules.utils import get_nullspace_relative_norm, get_jacobian, get_jacobian_2

sys.path.append("modules")

from data_set import *
from model_def import RNN
from utils import *


def _select_dataset(params):
    """Instantiate the dataset named by ``params["data"]["dataset_name"]``.

    For every dataset except "Reaching" the model's ``input_dim`` and
    ``output_dim`` entries in ``params["model"]`` are overwritten in place
    with the values stored for that dataset under ``params["data"]``.

    Raises:
        ValueError: if the dataset name is not one of the supported ones.
    """
    name = params["data"]["dataset_name"]
    if name == "Reaching":
        # Reaching keeps whatever dimensions are already in params["model"].
        return Reaching()

    if name == "Sinewave":
        dataset, dims_key = Sinewave(), "Sinewave"
    elif name == "Add":
        dataset, dims_key = Add(), "Add"
    elif name == "Add_zc":
        # Add_zc takes its dimensions from the "Add" entry.
        dataset, dims_key = Add_zc(), "Add"
    else:
        raise ValueError("unknown dataset_name: %r" % (name,))

    params["model"]["input_dim"] = params["data"][dims_key]["input_dim"]
    params["model"]["output_dim"] = params["data"][dims_key]["output_dim"]
    return dataset


def main(savname):
    """Train the network through every phase of the stored protocol.

    Loads ``params.npy`` from the directory ``savname``, seeds NumPy and
    PyTorch, builds the dataset and the single-layer ``RNN``, and then, for
    each protocol entry, runs the requested number of training epochs.
    Before every optimisation step the model is evaluated on two fixed test
    sets ('random' with and without feedback, 'random_pushed' with feedback).
    After each phase ``dataset.test`` is called and a checkpoint is written
    to ``<savname>/phase<k>_training``.

    Args:
        savname: directory containing ``params.npy``; checkpoints are
            written into the same directory.

    Raises:
        ValueError: if the configured dataset name is not supported.
        NotImplementedError: if ``params["model"]["nlayers"]`` is not 1.
    """
    savname = savname + "/"

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # params.npy files that come from a trusted source.
    params = np.load(savname + "params.npy", allow_pickle=True).item()
    protocol = params["model"]["protocol"]

    # SETUP SIMULATION #################
    rand_seed = params["model"]["rand_seed"]
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)

    # GPU usage #################
    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    device = torch.device("cuda" if use_cuda else "cpu")
    params["model"].update({"dtype": dtype, "device": device})

    # DATASET #################
    dataset = _select_dataset(params)

    # SETUP MODEL #################
    if params["model"]["nlayers"] != 1:
        # Fail here instead of continuing with an undefined model.
        raise NotImplementedError(
            "only single-layer networks (nlayers == 1) are implemented"
        )
    model = RNN(
        params["model"]["input_dim"],
        params["model"]["output_dim"],
        params["model"]["n"],
        dtype,
        params["model"]["dt"],
        params["model"]["tau"],
        fb_delay=params["model"]["fb_delay"],
        fb_density=params["model"]["fb_density"],
        recurrent=params["model"]["recurrent"],
        error_type=params["model"]["error_type"],
        error_detach=params["model"]["error_detach"],
    )
    # Snapshot of the freshly initialised parameters, stored in every checkpoint.
    init_state = copy.deepcopy(list(model.parameters()))
    print(model.rnn.weight_ih_l0.shape)

    # TO CUDA OR NOT TO CUDA #################
    if use_cuda:
        model = model.cuda()

    # TASK STUFF #################
    if params["data"]["dataset_name"] == "Reaching":
        model.pos_err = True

    # SETUP OPTIMIZER #################
    criterion = nn.MSELoss(reduction="none")
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=params["model"]["lr"]
    )

    # START INITIAL (PRE)TRAINING #################
    for phase in range(len(protocol)):
        print("\n####### PHASE %d #######" % phase)
        # get info from protocol: entry = (dataset name, number of epochs, ...)
        ph_ntrials = protocol[phase][1]
        ph_datname = protocol[phase][0]
        batch_size = params["model"]["batch_size"]

        # create test data (fixed for the whole phase)
        _, stimulus_tr, pert_tr, _, stim_ref_tr = dataset.prepare_pytorch(
            params, 'random', test_set=True
        )
        _, stimulus_trp, pert_trp, _, stim_ref_trp = dataset.prepare_pytorch(
            params, 'random_pushed', test_set=True
        )

        # ACTUAL TRAINING STARTS
        lc = []  # per-epoch list of the individual loss terms
        model.train()

        # PREPARE TO RECORD #########
        test_losses_r = []
        test_losses_r_nofb = []
        test_losses_rp = []

        for epoch in range(ph_ntrials):

            # create training data
            _, stimulus, pert, _, stim_ref = dataset.prepare_pytorch(
                params, "random_pushed", 1, batch_size
            )

            # PREP WORK #######
            optimizer.zero_grad()

            # Evaluate on the fixed test sets before this epoch's update.
            with torch.no_grad():
                # with feedback
                loss_test = dataset.test_dt(
                    model, stimulus_tr, pert_tr, stim_ref_tr, fb_in=True
                )
                test_losses_r.append(loss_test.detach())
                # without feedback
                loss_test = dataset.test_dt(
                    model, stimulus_tr, pert_tr, stim_ref_tr, fb_in=False
                )
                test_losses_r_nofb.append(loss_test.detach())
                # with feedback + pushed
                loss_test = dataset.test_dt(
                    model, stimulus_trp, pert_trp, stim_ref_trp, fb_in=True
                )
                test_losses_rp.append(loss_test.detach())

            # add regularization
            # term 1: parameters (alpha1 for in/out/feedback, gamma1 for recurrent)
            regin = params["model"]["alpha1"] * model.rnn.weight_ih_l0.norm(2)
            regout = params["model"]["alpha1"] * model.output.weight.norm(2)
            regoutb = params["model"]["alpha1"] * model.output.bias.norm(2)
            regfb = params["model"]["alpha1"] * model.feedback.weight.norm(2)
            regfbb = params["model"]["alpha1"] * model.feedback.bias.norm(2)
            regrec = params["model"]["gamma1"] * model.rnn.weight_hh_l0.norm(2)

            # FORWARD PASS ############
            output, hidden, extras = model(
                stimulus[0],
                pert[0],
                stim_ref[0],
            )

            # BACKWARD PASS ############
            # The training loss is the mean squared model output (MSE against zeros).
            loss_train = criterion(output, output * 0).mean()
            # term 2: hidden activity
            regact = params["model"]["beta1"] * hidden.pow(2).mean()
            reg = regin + regrec + regout + regoutb + regfbb + regfb
            loss = loss_train + reg + regact
            # NOTE(review): retain_graph=True is kept from the original code;
            # confirm whether the model actually needs the graph after backward.
            loss.backward(retain_graph=True)

            # CLIP THOSE GRADIENTS TO AVOID EXPLOSIONS ########
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), params["model"]["clipgrad"]
            )

            # APPLY GRADIENTS TO PARAMETERS ########
            optimizer.step()

            train_running_loss = [
                loss_train.detach().item(),
                regact.detach().item(),
                regin.detach().item(),
                regrec.detach().item(),
                regout.detach().item(),
                regoutb.detach().item(),
                regfb.detach().item(),
                regfbb.detach().item(),
            ]

            # printing ("Loss" is the total loss including all regularisers)
            toprint = OrderedDict()
            toprint["Loss"] = loss
            toprint["In"] = regin
            toprint["Rec"] = regrec
            toprint["Out"] = regout
            toprint["OutB"] = regoutb
            toprint["Fb"] = regfb
            toprint["FbB"] = regfbb
            toprint["Act"] = regact

            print(
                ("Epoch=%d | " % (epoch))
                + " | ".join("%s=%.4f" % (k, v) for k, v in toprint.items())
            )
            lc.append(train_running_loss)

        print("MODEL TRAINED!")
        with torch.no_grad():
            dataset.test(model, lc, params, ph_datname, savname, phase)
        print("MODEL TESTED!")

        # save this phase
        torch.save(
            {
                "epoch": ph_ntrials,
                "model_state_dict": model.state_dict(),
                "model_init_state_dict": init_state,
                "optimizer_state_dict": optimizer.state_dict(),
                "lc": np.array(lc),
                "loss_test_r": test_losses_r,
                "loss_test_r_nofb": test_losses_r_nofb,
                "loss_test_rp": test_losses_rp,
                "params": params,
            },
            savname + "phase" + str(phase) + "_training",
        )



if __name__ == "__main__":
    # main() requires the run directory (the one holding params.npy);
    # take it from the command line instead of calling main() without it.
    if len(sys.argv) != 2:
        sys.exit("usage: python %s <run_directory>" % sys.argv[0])
    main(sys.argv[1])
