# Adds the entropy of the state, scaled by a factor alpha, to the reward.
# The trace is computed with the library routine (torch.trace in the code below) to make it faster.
# To change the weight of the state-entropy term, modify the parameter that rescales it (in the code it is called 'alpha', an unfortunate naming choice by the developer).
# For plot reproducibility, also choose alpha = 0.0, 0.2, 0.5, 0.8, 1.5 and name the respective files 'Figure5_alpha00', 'Figure5_alpha02', 'Figure5_alpha05', 'Figure5_alpha08', 'Figure5_alpha15'.

#%% SETTING THE PARAMETERS

import numpy as np
import math
import torch
import torch.optim as optim
import torch.nn as nn
import dill
from sklearn.decomposition import PCA

# Base name of the output file; keep it in sync with the value of 'alpha' below
name_file = 'Figure5_alpha00'

# All the settings of the run, grouped by the object that consumes them
params = dict(
    N=100,  # size of the recurrent neural network (number of neurons)
    Nc=8,  # number of components of an action
    # ---- Simulation ----
    alpha=0.0,  # factor rescaling the state-entropy term
    up_threshold=0.11,  # upper bound on the energy; 0.10 is the energy when all x[:] = 0
    gamma_disc=0.9,  # discount factor applied to future intrinsic rewards
    TotalT=30,  # total simulated time
    dt=0.05,  # integration time step
    Naverage=10,  # number of repetitions the observables are averaged over
    # ---- Recurrent network ----
    tau_h=1,  # timescale of the network activity
    RNNnonlinearity=torch.tanh,  # nonlinearity of the recurrent units
    gain_parameter=5.,  # gain parameter g
    J=1,  # coupling strength
    # ---- Feedforward network ----
    Ni=20,  # number of input neurons
    Nh=256,  # number of hidden neurons
    No=1,  # number of output neurons (a single one: the value function)
    activation=nn.functional.relu,  # nonlinearity of the hidden layer
    output=nn.functional.relu,  # nonlinearity of the output neuron
    action_scaling=1.2,  # magnitude of the actions injected in the RNN, \rho
    delta_thresh=0.0005,  # margin below the threshold within which a state counts as 'close' to terminal
    delta_far_thresh=0.002,  # margin below the threshold beyond which a state counts as 'far' from terminal
    # ---- Training of the feedforward network ----
    Nbatches=10,  # number of trajectories in each batch
    Nepochs=60,  # number of training epochs
    HalfTraining=0,  # epoch at which an example trajectory is recorded to inspect the training
    learning_rate=0.01,  # learning rate of the optimizer
)

#%%CLASS DEFINITION

############################### my RNN class ############################
class myRNN():
    """Recurrent network driven by discrete actions, plus the sampler that picks those actions.

    The activity ``x`` is integrated with the Euler scheme
    ``x' = x + dt * (-x / tau_h + activation(J x + S a))``.
    Actions are drawn from a policy proportional to
    ``exp(alpha * Tr(J D) + gamma * V(x'(x, a)))``, where ``D`` is the diagonal matrix
    ``1 - activation(J x + S a)**2`` and ``V`` is provided by the value approximator
    passed to ``extract_action``.
    """

    def __init__(self, params):
        """Build the network from the ``params`` dictionary.

        Keys read: 'N', 'Nc', 'J', 'gain_parameter', 'action_scaling', 'alpha',
        'gamma_disc', 'up_threshold', 'tau_h', 'dt', 'TotalT', 'RNNnonlinearity',
        'delta_thresh', 'delta_far_thresh'.
        """
        n_units = params['N']
        n_components = params['Nc']
        n_actions = 2**n_components

        # J [N x N]: standard-normal matrix scaled by g * J / sqrt(N).
        # The scalar prefactor is cast to float32 BEFORE the product: J/np.sqrt(N) is a
        # np.float64 scalar, and under NumPy >= 2 promotion rules multiplying it with a
        # float32 array yields float64, which then cannot be multiplied with the float32
        # activity in torch.matmul. Casting first keeps J float32 on every NumPy version.
        coupling_scale = params['J']/np.sqrt(n_units)
        prefactor = np.float32(params['gain_parameter'] * coupling_scale)
        self.J = torch.from_numpy(prefactor * np.random.standard_normal([n_units, n_units]).astype(np.float32))
        # x [N]: uniform random activity in [0, 1)
        self.x = torch.rand(n_units)
        # S [N x Nc]: positive random matrix scaled by rho (same float32 precaution as for J)
        rho = np.float32(params['action_scaling'])
        self.S = torch.from_numpy(rho * np.random.rand(n_units, n_components).astype(np.float32))

        ####################### Variables for the action extraction ############################################
        # Policy over the 2**Nc action combinations, initialised uniform (float64)
        self.pi = torch.from_numpy(np.ones(n_actions)/n_actions)
        # Cumulative of the policy, initialised to zero (float64)
        self.pi_cumulative = torch.from_numpy(np.zeros(n_actions))
        # How many times each of the 2**Nc actions has been sampled
        self.f_actions = np.zeros(n_actions)  # in states far from the threshold
        self.f_actions_thresh = np.zeros(n_actions)  # in states close to the threshold
        # Sampled actions, kept to build the PCA
        self.PCA_actions = []  # in states far from the threshold
        self.PCA_actions_thresh = []  # in states close to the threshold

        ####################### Parameters of the functions in here ############################################
        self.N = n_units  # Number of neurons in the RNN
        self.alpha = params['alpha']  # Weight of the state-entropy term in the policy
        self.gamma = params['gamma_disc']  # Discount factor
        self.Nc = n_components  # Number of action components
        self.x_th_plus = params['up_threshold']  # Energy threshold defining the terminal states
        self.tau_h = params['tau_h']  # Intrinsic timescale of the RNN dynamics
        self.dt = params['dt']  # Integration time step
        self.TotalT = params['TotalT']  # Maximum duration of a simulation
        self.t = math.floor(self.TotalT/self.dt)  # Maximum duration in units of dt
        self.activation = params['RNNnonlinearity']  # Nonlinearity of the RNN
        self.delta_thresh = params['delta_thresh']  # Margin defining 'close' to the terminal state
        self.delta_far_thresh = params['delta_far_thresh']  # Margin defining 'far' from the terminal state

    ################# The following setters are for the activity ##########

    def set_initial_activity(self, x0):
        """Set the activity to a detached copy of ``x0``."""
        self.x = x0.detach().clone()

    def set_zero_initial_activity(self):
        """Zero the activity in place."""
        self.x[:] = 0

    def set_random_initial_activity(self, scale = 0.05):
        """Draw a uniform random activity in ``[-scale/2, scale/2)``."""
        self.x = (torch.rand(self.N) - 0.5) * scale

    ################## The following functions define the evolution of the dynamics ##########
    ################## x' = x + dt * (-x/tau + Phi (Jx + Sa)) ################################

    def _euler_step(self, drive):
        """Return one Euler step from ``self.x`` for the pre-activation input ``drive``."""
        return self.x + (self.dt) * ( -self.x/self.tau_h + self.activation( drive ) )

    def look_forward(self, act):
        """Return the state one step ahead under action ``act`` WITHOUT modifying ``self.x``."""
        return self._euler_step( torch.matmul(self.J, self.x) + torch.matmul(self.S, act) )

    def evolve_forward(self, act):
        """Advance ``self.x`` by one step under action ``act``."""
        self.x = self._euler_step( torch.matmul(self.J, self.x) + torch.matmul(self.S, act) )

    def evolve_forward_free(self):
        """Advance ``self.x`` by one step with no action term (uncontrolled dynamics)."""
        self.x = self._euler_step( torch.matmul(self.J, self.x) )

    ######### The following functions are for the extraction of the actions to be injected in the RNN ##########

    def reset_actions_counter(self):
        """Zero both action counters and empty both lists of sampled actions."""
        self.f_actions[:] = 0
        self.f_actions_thresh[:] = 0
        self.PCA_actions = []
        self.PCA_actions_thresh = []

    def check_terminal_state(self, x, delta_th = 0.):
        """Flag whether the energy of ``x`` exceeds ``up_threshold - delta_th``.

        The energy is ``||x + 1||_2 / len(x)``: the L2 norm of the activity shifted by 1,
        divided by the number of units. Different tasks can be implemented by changing
        this function. Returns a boolean tensor.
        """
        energy = torch.norm( x + 1 , dim = 0 ) / x.size( dim = 0 )
        return energy > self.x_th_plus - delta_th

    def trace_Jacobian(self, x, a):
        """Return ``Tr(J D)`` with ``D = diag(1 - activation(J x + S a)**2)``.

        This is the state-entropy term of the policy. The dense product is kept on
        purpose so that the numerical value is unchanged.
        """
        gain = torch.diag_embed(1 - pow( self.activation( torch.matmul(self.J, x) + torch.matmul(self.S, a) ) , 2 ) )
        return torch.trace(torch.matmul(self.J, gain))

    def extract_action(self, fnn):
        """Sample one action from the policy built on the current value approximation.

        ``fnn`` must expose ``all_actions`` (one action per row, 2**Nc rows) and
        ``Value_approx(x)``. Every action is scored with
        ``exp(alpha * Tr(J D) + gamma * V(x'))``, where V is 0 for terminal next states.
        The sampled action is also recorded in the close-to/far-from threshold counters.

        Returns the sampled action (a row of ``fnn.all_actions``) and the partition
        function ``Z``, i.e. the sum of the un-normalised policy weights.
        """
        n_actions = 2**self.Nc
        # State from which the action is sampled, kept for the bookkeeping at the end
        x_true = self.x.detach().clone()

        # Un-normalised weight of every action
        for idx in range(n_actions):
            candidate = fnn.all_actions[idx, :]
            # Deterministic next state x'(x, a) reached with this action
            x_future = self.look_forward(candidate)
            if self.check_terminal_state(x_future):
                # Terminal next states have zero value; a tensor, like the output of the approximator
                value = torch.tensor(0.)
            else:
                value = fnn.Value_approx(x_future).detach()
            # Entropy term, evaluated in the current state with the candidate action
            trace = self.trace_Jacobian(self.x, candidate)
            self.pi[idx] = torch.exp(self.alpha * trace.detach() + self.gamma * value)

        # Normalise by the partition function; every entry of pi was overwritten above
        Z = torch.sum( self.pi )
        self.pi = torch.divide(self.pi, Z)

        # Running sum of the policy
        self.pi_cumulative[0] = self.pi[0]
        for idx in range(1, n_actions):
            self.pi_cumulative[idx] = self.pi_cumulative[idx-1] + self.pi[idx]

        # Inverse-CDF sampling: first index whose cumulative probability exceeds r.
        # The bound on the index guards against round-off leaving the last cumulative below r.
        r = np.random.rand()
        ii = 0
        while self.pi_cumulative[ii] <= r and ii < n_actions - 1:
            ii += 1

        ##### Record the sampled action according to how close the current state is to the terminal state #####
        # Within delta_thresh of the threshold: counted as 'close'
        if self.check_terminal_state(x_true, self.delta_thresh):
            self.f_actions_thresh[ii] = self.f_actions_thresh[ii] + 1
            self.PCA_actions_thresh.append(fnn.all_actions[ii,:])
        # Not even within delta_far_thresh of the threshold: counted as 'far'
        if not self.check_terminal_state(x_true, self.delta_far_thresh):
            self.f_actions[ii] = self.f_actions[ii] + 1
            self.PCA_actions.append(fnn.all_actions[ii,:])

        return fnn.all_actions[ii,:], Z

    ############ Auxiliary functions to run easily trajectories to see the dynamics ############

    def run_example_traj(self, fnn, x0):
        """Run one controlled trajectory from ``x0`` and return its activity, shape [N, t].

        The run stops at the first terminal state (its energy is printed); the columns
        after that step are left at zero.
        """
        Activity = torch.zeros([self.N, self.t])
        self.x = x0.clone()

        for tt in range(self.t):
            Activity[:,tt] = self.x
            if self.check_terminal_state(self.x):
                # Print the energy to keep track of errors, then end the dynamics
                print('Energy being = ', torch.norm( self.x + 1 , dim = 0 ) / self.x.size( dim = 0 ))
                break
            # The partition function is not needed here
            a, _ = self.extract_action(fnn)
            self.evolve_forward(a)

        return Activity

    def run_example_traj_free(self, x0):
        """Run one uncontrolled trajectory from ``x0`` and return its activity, shape [N, t]."""
        Activity_free = torch.zeros([self.N, self.t])
        self.x = x0.clone()

        for tt in range(self.t):
            Activity_free[:,tt] = self.x
            self.evolve_forward_free()

        return Activity_free
    
################################# My FeedForward Network Class ############################
class myFFN():
    """One-hidden-layer feedforward network approximating the value function.

    Structure: Ni input nodes (+ bias) -> Nh hidden nodes (+ bias) -> No output nodes.
    The object also stores the table of all 2**Nc binary actions.
    """

    def __init__(self, params):
        """Build weights, biases and the (still empty) action table from ``params``.

        Keys read: 'Ni', 'Nh', 'No', 'Nc', 'activation', 'output', 'gamma_disc'.
        """
        n_in, n_hid, n_out = params['Ni'], params['Nh'], params['No']

        # Weights start from very small positive values
        InpHid0 = 1/(n_in * n_hid) * np.random.rand(n_hid, n_in).astype(np.float32)  # input -> hidden, [Nh x Ni]
        HidOut0 = 1/(n_hid * n_out) * np.random.rand(n_out, n_hid).astype(np.float32)  # hidden -> output, [No x Nh]
        # Biases start from zero
        bInp0 = np.zeros(n_hid).astype(np.float32)  # bias feeding the hidden layer
        bHid0 = np.zeros(n_out).astype(np.float32)  # bias feeding the output layer

        # Trainable leaf tensors: autograd tracks them and backward() fills their .grad.
        # requires_grad_() replaces the deprecated torch.autograd.Variable wrapper.
        self.InpHid = torch.from_numpy(InpHid0).requires_grad_(True)
        self.HidOut = torch.from_numpy(HidOut0).requires_grad_(True)
        self.bInp = torch.from_numpy(bInp0).requires_grad_(True)
        self.bHid = torch.from_numpy(bHid0).requires_grad_(True)

        ####################### Parameters of the functions in here ############################################
        self.activation = params['activation']  # nonlinearity of the hidden layer
        self.output = params['output']  # nonlinearity of the output node
        self.Nc = params['Nc']  # number of action components
        self.Ni = n_in  # number of input nodes, i.e. of RNN activities that are read
        self.Nh = n_hid  # number of hidden nodes
        self.No = n_out  # number of output nodes (a single one for the value function)
        self.gamma = params['gamma_disc']  # discount factor, used to cap the value
        # Table of all possible actions, [2**Nc x Nc]; filled by initialize_all_actions
        self.all_actions = torch.zeros(2**params['Nc'], params['Nc'])

    def initialize_all_actions(self):
        """Fill ``all_actions`` with every combination of Nc components in {-1, +1}.

        Row ``act`` holds the binary digits of ``act``, most significant first,
        mapped from {0, 1} to {-1, +1}.
        """
        bits = [[(act >> (self.Nc - 1 - comp)) & 1 for comp in range(self.Nc)]
                for act in range(2**self.Nc)]
        binary = torch.tensor(bits, dtype=self.all_actions.dtype).reshape(2**self.Nc, self.Nc)
        # Actions are -1 / +1 rather than 0 / 1
        self.all_actions = 2 * binary - 1

    def Value_approx(self, x0):
        """Return the network output for the activity vector ``x0``.

        Only the LAST ``Ni`` entries of ``x0`` are fed to the network. The output is
        capped at ``log(2**Nc) / (1 - gamma)``, the value reached by a uniform policy
        at every step; when the cap applies a constant tensor is returned instead of
        the network output.
        """
        # Native slicing instead of np.delete: same entries, but it stays inside torch
        # and therefore also works for inputs that require grad.
        first_kept = max(len(x0) - self.Ni, 0)
        x_i = torch.zeros(self.Ni)
        x_i[0:self.Ni] = x0[first_kept:]

        # Forward pass through the hidden layer and the output node
        hidden = self.activation( torch.matmul( self.InpHid, x_i ) + self.bInp )
        Value = self.output( torch.matmul( self.HidOut, hidden ) + self.bHid )

        # The maximum entropy comes from the uniform distribution: cut off any overshoot
        if Value > 1./(1. - self.gamma) * np.log(2**self.Nc):
            Value = torch.tensor( 1./ ( 1. - self.gamma ) * np.log( 2**self.Nc ))

        return Value
    
#### Class with the simulator: training of the FFN, could be incorporated in the myFFN class ####
#### Class with the simulator: training of the FFN, could be incorporated in the myFFN class ####
class Simulator():
    """Builds the networks and trains the value approximator on batches of trajectories."""

    def __init__(self, params):
        """Read the training settings: 'Nepochs', 'HalfTraining', 'Nbatches', 'learning_rate'."""
        self.EPOCHS = params['Nepochs']  # Number of training epochs
        self.HalfTraining = params['HalfTraining']  # Epoch at which an example trajectory is recorded
        self.BATCHES = params['Nbatches']  # Number of trajectories in each batch
        self.lr = params['learning_rate']  # Learning rate of the optimizer

    def start_simulation(self, params):
        """Create and return ``(rnn, fnn)``, with the action table of ``fnn`` already filled."""
        rnn = myRNN(params)
        fnn = myFFN(params)
        fnn.initialize_all_actions()
        return rnn, fnn

    def Cost_function(self, rnn, fnn):
        """Run one batch of trajectories and return the Bellman-error cost and some observables.

        The cost is ``mean_batch( sum_t W * (V(x_t) - lnZ(x_t))**2 / t_end )``, with
        W = 0.1 on ordinary steps and W = 1 (target 0) on the terminal step.
        All trajectories of the batch start from the same random initial activity.

        Returns ``(Cost, Variability, Variability_single, Variability_delta, mean t_end)``;
        the three variability outputs are [BATCHES] tensors.
        """
        y = torch.zeros([self.BATCHES, rnn.t])  # V(x) at each step of each trajectory
        y_hat = torch.zeros([self.BATCHES, rnn.t])  # lnZ(x) at each step of each trajectory
        W = torch.zeros([self.BATCHES, rnn.t])  # weight of each step in the cost
        Variability = torch.zeros(self.BATCHES)  # std of the activity over the whole trajectory
        Variability_single = torch.zeros(self.BATCHES)  # mean over neurons of the std in time
        Variability_delta = torch.zeros(self.BATCHES)  # mean over neurons of the std of consecutive jumps
        t_end = torch.zeros(self.BATCHES)  # step at which each trajectory ended
        x0 = (torch.rand(rnn.N) - 0.5) * 0.01  # Random initial activity between -0.005 and 0.005

        for bp in range(self.BATCHES):

            rnn.set_initial_activity(x0)
            Activity = torch.zeros([rnn.N, rnn.t])
            # Optimistic default: the trajectory survives the whole simulation
            t_end[bp] = rnn.t

            for tt in range(rnn.t):

                Activity[:, tt] = rnn.x
                y[bp, tt] = fnn.Value_approx(rnn.x)

                if rnn.check_terminal_state(rnn.x):
                    # Terminal state: the target value is zero and the error is weighted ten times more
                    y_hat[bp, tt] = 0.
                    W[bp, tt] = 1.
                    t_end[bp] = tt
                    break

                a, Z = rnn.extract_action(fnn)
                y_hat[bp, tt] = torch.log(Z).detach()  # natural logarithm
                W[bp, tt] = 0.1
                rnn.evolve_forward(a)

            steps = int(t_end[bp].item())

            # Std over all units and times; Activity is zero after t_end, so those entries are excluded
            Variability[bp] = torch.std(Activity[Activity.nonzero(as_tuple = True)])

            # Std in time of each neuron up to and including t_end, averaged across neurons
            Variability_single[bp] = torch.mean(torch.std(Activity[:, :steps+1], axis = 1))

            # Std in time of the jumps between consecutive activities, averaged across neurons.
            # When the trajectory survives the whole run (steps == rnn.t) the last stored column
            # is rnn.t - 1, so the index is clipped to keep both slices the same width.
            last = min(steps, rnn.t - 1)
            jumps = Activity[:, :last] - Activity[:, 1:last+1]
            Variability_delta[bp] = torch.mean(torch.std(jumps, axis = 1))

        # Sum the weighted squared errors over time (entries after t_end are zero), normalise each
        # trajectory by its length and average over the batch. The length is clamped to 1 so that a
        # trajectory starting in a terminal state does not divide by zero.
        per_trajectory = torch.sum(W*torch.pow(y - y_hat, 2), 1, keepdim = True).view(self.BATCHES)
        Cost = torch.sum(torch.div(per_trajectory, t_end.clamp(min = 1))) / self.BATCHES

        return Cost, Variability, Variability_single, Variability_delta, torch.mean(t_end)

    def train_W(self, rnn, fnn, x0):
        """Train the weights of ``fnn`` with Adam for ``EPOCHS`` epochs.

        At epoch ``HalfTraining`` an example trajectory starting from ``x0`` is recorded.

        Returns ``(MSE, STD, STD_single, T_end, Activity_train)``: four per-epoch NumPy
        arrays and the recorded trajectory, which is ``None`` when ``HalfTraining`` is not
        one of the epochs that were run.
        """
        # Adam with standard betas on the four trainable tensors of the ffn
        opt = optim.Adam([fnn.InpHid, fnn.HidOut, fnn.bInp, fnn.bHid], lr = self.lr)

        MSE = np.zeros([self.EPOCHS])  # cost per epoch
        STD = np.zeros([self.EPOCHS])  # batch-mean standard deviation per epoch
        STD_single = np.zeros([self.EPOCHS])  # batch-mean single-neuron standard deviation per epoch
        T_end = np.zeros([self.EPOCHS])  # mean trajectory length per epoch
        V_0 = np.zeros([self.EPOCHS])  # value of the zero state per epoch (only printed)
        Activity_train = None  # stays None unless epoch HalfTraining is reached

        for ep in range(self.EPOCHS):

            # The action counters are reset at every epoch
            rnn.reset_actions_counter()
            Cost, Variability, Variability_single, _, tt = self.Cost_function(rnn, fnn)

            # Explicit .item() so NumPy stores plain floats, not tensors attached to the graph
            MSE[ep] = Cost.item()
            STD[ep] = torch.mean(Variability).item()
            STD_single[ep] = torch.mean(Variability_single).item()
            T_end[ep] = tt.item()

            if ep == self.HalfTraining:
                Activity_train = rnn.run_example_traj(fnn, x0)

            # Gradient of the cost, one optimizer step, then clear the gradients for the next epoch
            Cost.backward()
            opt.step()
            opt.zero_grad()

            # Monitor the value function in the zero state
            rnn.set_zero_initial_activity()
            V_0[ep] = fnn.Value_approx(rnn.x).item()
            print('At ep ', ep, ' of the Opt the MSE is', Cost.detach().numpy(), 'and time of end is', tt.detach().numpy(), 'and Value ', V_0[ep])

        return MSE, STD, STD_single, T_end, Activity_train
    
#%% STARTING THE SIMULATION CREATING THE ENVIRONMENT (rnn) AND THE AGENT (fnn)            

# Seed the NumPy and the PyTorch random number generators with the same
# fixed value so that repeated runs of the script draw the same numbers.
_RNG_SEED = 100
np.random.seed(_RNG_SEED)
torch.manual_seed(_RNG_SEED)

#%%DEFINING QUANTITIES
### Every container has one leading row per independently simulated network (Naverage rows in total).
_n_networks = params['Naverage']
_n_epochs = params['Nepochs']
# Number of integration steps stored for one example trajectory
_n_steps = math.floor(params['TotalT']/params['dt'])
# Number of columns of the post-training variability arrays
# (the loop below sets sim.BATCHES to the same number before filling them)
_n_eval = 50

_per_epoch = (_n_networks, _n_epochs)
_per_eval = (_n_networks, _n_eval)
_scalar = (_n_networks, 1)
_activity_shape = (_n_networks, params['N'], _n_steps)
_energy_shape = (_n_networks, 1, _n_steps)

# --- Learning curves, one value per training epoch ---
MSE = np.zeros(_per_epoch)         # mean squared error
STD = np.zeros(_per_epoch)         # standard deviation
STD_single = np.zeros(_per_epoch)  # standard deviation of single neurons
T_end = np.zeros(_per_epoch)       # final (survival) time

# --- Variability measured after training, filled from sim.Cost_function ---
Var = np.zeros(_per_eval)          # standard deviation
Var_single = np.zeros(_per_eval)   # standard deviation of single neurons
Var_delta = np.zeros(_per_eval)    # standard deviation of the activity jumps (DeltaX)

# --- Effective dimensionalities, one number per network ---
ED = np.zeros(_scalar)             # of the actions
ED_thresh = np.zeros(_scalar)      # of the actions taken close to the threshold
ED_s = np.zeros(_scalar)           # of the example trajectory run with the trained fnn
ED_s_free = np.zeros(_scalar)      # of the example trajectory of the free run

# --- Example trajectories and their energies, one column per time step ---
Activity = np.zeros(_activity_shape)       # activity of the N neurons, run with the trained fnn
Activity_free = np.zeros(_activity_shape)  # activity of the N neurons, free run
Energy = np.zeros(_energy_shape)           # energy along the run with the trained fnn
Energy_free = np.zeros(_energy_shape)      # energy along the free run

#%% 
############ LOOP OVER Naverage AGENTS THAT I AM SIMULATING #############

def _effective_dimensionality(samples):
    """Return the effective dimensionality (participation ratio) of *samples*.

    samples : 2D array-like with one observation per row; it is handed to
        sklearn's PCA.fit unchanged.

    The result is (sum lambda)**2 / sum(lambda**2), where lambda are the PCA
    explained variances. This is the same as 1 / sum(p**2) with p the
    eigenvalues normalized to sum to one.
    """
    eigenvalues = PCA().fit(samples).explained_variance_
    return (np.sum(eigenvalues)**2)/np.sum(eigenvalues**2)


for i in range(params['Naverage']):

    print('Random extractions of the network ', i)
    # Each iteration builds a new rnn and a new fnn, with their own random initialization of the weights
    sim = Simulator(params)
    rnn, fnn = sim.start_simulation(params)

    # Initial activity drawn uniformly in [-0.05, 0.05); the same x0 is used for training and for both example runs
    x0 = (torch.rand(rnn.N) - 0.5) * 0.1

    # Train the weights of the feedforward network to approximate the value function for this specific rnn
    MSE[i, :], STD[i,:], STD_single[i,:], T_end[i,:], Activity_train = sim.train_W(rnn, fnn, x0)

    # Reset the record of the actions taken, so that the PCA of the actions below
    # presumably reflects the example trajectory rather than the training epochs (verify in rnn)
    rnn.reset_actions_counter()

    # Starting from x0, run one example trajectory in the MOP case and one in the free case
    rnn.set_initial_activity(x0)
    traj = rnn.run_example_traj(fnn, rnn.x)
    Activity[i, :, :] = traj
    rnn.set_initial_activity(x0)
    traj_free = rnn.run_example_traj_free(rnn.x)
    Activity_free[i,:,:] = traj_free

    # Effective dimensionality of the two trajectories; they are transposed so that each time step is one PCA sample
    ED_s_free[i,0] = _effective_dimensionality(traj_free.T)
    ED_s[i,0] = _effective_dimensionality(traj.T)

    # Energy at each time step: norm over the first axis of (activity + 1), divided by the size of that axis
    Energy[i, 0, :] = torch.norm( ( traj + 1 ), dim = 0 ) / traj.size( dim = 0 )
    # NOTE: Energy_train is not written into any of the arrays stored at the end of the script
    Energy_train = torch.norm( ( Activity_train + 1 ), dim = 0 ) / Activity_train.size( dim = 0 )
    Energy_free[i, 0, :] = torch.norm( ( traj_free + 1 ), dim = 0 ) / traj_free.size( dim = 0 )

    # ANALYSIS OF THE PCAs
    # (plain comment on purpose: a '#%%' cell marker here would split the loop body across IDE cells)
    # Stack the recorded action tensors into a single NumPy array, one action per row
    X = np.array([tensor.numpy() for tensor in rnn.PCA_actions])
    ED[i,0] = _effective_dimensionality(X)

    X_thresh = np.array([tensor.numpy() for tensor in rnn.PCA_actions_thresh])
    # PCA normalizes the explained variance by (n_samples - 1): with a single sample the
    # ratio would be NaN, so at least two samples are required; otherwise 0 is stored.
    if X_thresh.shape[0] > 1 :
        ED_thresh[i,0] = _effective_dimensionality(X_thresh)
    else:
        ED_thresh[i,0] = 0

    # WITH TRAINED DATA, I GENERATE MORE TRAJECTORIES TO SEE THE FINAL STANDARD DEVIATIONS
    # The number of trajectories is read from the width of Var, so that the batch size
    # always matches the arrays that receive the results.
    sim.BATCHES = Var.shape[1]
    # The cost function generates BATCHES trajectories; only the variability outputs are kept
    _, Var[i, :], Var_single[i, :], Var_delta[i,:],_ = sim.Cost_function(rnn, fnn)

#%% SAVING THE DATA: each list of results is dumped with dill to its own file, named after name_file

# Pairs of (file name suffix, list of objects to store). Note that rnn and fnn
# are the networks left over from the last iteration of the loop above.
_outputs = (
    ('_training.pkl', [params, MSE, STD, STD_single, Var, Var_single, Var_delta, T_end, ED, ED_thresh, ED_s, ED_s_free, rnn, fnn]),
    ('_activity.pkl', [Activity, Activity_free, Energy, Energy_free]),
)

for _suffix, data in _outputs:
    with open(name_file + _suffix, 'wb') as file:
        dill.dump(data, file)

