#%% SETTING THE PARAMETERS
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import dill

#%% PLOT FEATURES

# Palettes: six greys used as the default colour cycle, plus accent colours
# for the MOP and greedy curves.
pal_grey = ['dee2e6', 'ced4da', 'adb5bd', '6c757d', '495057', '343a40']
pal_grey = [f"#{c}" for c in pal_grey]
pal_MOP = '#c55986ff'
pal_greedy = '#354a78ff'

# Set the default style.
# sns.reset_defaults() restores *every* rcParam to matplotlib's defaults and
# sns.set_theme() sets font.family itself, so the LaTeX / STIX font settings
# must be passed through ``rc`` (which seaborn applies last). Assigning them
# to plt.rcParams before these two calls had no effect.
sns.reset_defaults()  # useful when adjusting style a lot
sns.set_theme(
    context="paper",
    style="ticks",
    palette=pal_grey,
    rc={
        "text.usetex": True,  # needs a working LaTeX installation
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
        "pdf.fonttype": 42,  # embed TrueType fonts in PDF output
        "svg.fonttype": "none",  # keep SVG text as text (font not embedded)
        "figure.facecolor": "white",
        "figure.dpi": 100,
        "axes.facecolor": "None",
        "axes.spines.left": True,
        "axes.spines.bottom": True,
        "axes.spines.right": False,
        "axes.spines.top": False,
    },
)

#%% LOAD DATA

name_file = 'Figure6'

# Read the dill-serialised results and unpack the stored 12-element sequence
# into the names used by the plotting cells below.
with open(f'{name_file}.pkl', 'rb') as file:
    data = dill.load(file)

(params, MSE, STD, STD_single, T_end, Activity_MOP, H_mop,
 Activity_free, ED_MOP_far, ED_MOP_thresh, rnn, fnn) = data


#%% ACTIVITY OCCUPANCY

# Histogram grid for unit activations: N_bins equal-width bins on
# [min_bins, max_bins].
N_bins = 50
min_bins = -1
max_bins = 1
delta_bins = (max_bins - min_bins)/N_bins


def _occupancy(activity, n_avg, n_units, n_steps):
    """Count activations per histogram bin and time step.

    Counts are summed over the first ``n_avg`` entries of axis 0 and the first
    ``n_units`` entries of axis 1 of ``activity``; axis 2 is time and only the
    first ``n_steps`` steps are used.

    Bins are half-open ``[lo, hi)`` except the last, which also contains
    ``max_bins``. This means values lying exactly on a bin edge (such as a
    saturated +/-1) are counted rather than dropped. Values outside
    ``[min_bins, max_bins]`` and NaNs are ignored.

    Returns:
        Float array of shape ``(N_bins, n_steps)``.
    """
    x = np.asarray(activity, dtype=float)[:n_avg, :n_units, :n_steps]
    in_range = (x >= min_bins) & (x <= max_bins)  # NaN compares False
    bins = np.floor((x[in_range] - min_bins) / delta_bins).astype(int)
    bins = np.clip(bins, 0, N_bins - 1)  # x == max_bins goes in the last bin
    steps = np.nonzero(in_range)[2]  # time index of every in-range value
    occupancy = np.zeros((N_bins, n_steps))
    # Unbuffered scatter-add: repeated (bin, step) pairs each contribute 1.
    np.add.at(occupancy, (bins, steps), 1)
    return occupancy


Occupancy_free = _occupancy(Activity_free, params['Naverage'], rnn.N, rnn.t)
Occupancy_MOP = _occupancy(Activity_MOP, params['Naverage'], rnn.N, rnn.t)



#%% COLORMAP FOR THE ENERGIES

from matplotlib.colors import ListedColormap
from matplotlib.cm import magma, inferno, ocean, turbo, viridis, plasma
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Index of the (arbitrary) trajectory shown in the activity / energy panels
traj = 1

# Colour scale for the MOP energy trace. The maximum is the maximum entropy,
# log(2**Nc); the minimum is the smallest entropy reached by the plotted
# trajectory.
norm = Normalize(vmin=np.min(H_mop[traj, 0, :]), vmax=np.log(2**params['Nc']))

# Layout: activity panels on top and energies below. 'no' is an empty spacer
# and 'colorbar' a thin axis for the entropy colour bar. The panel key is
# 'Tend' because that is the name the t_end panel code indexes with.
fig, axes = plt.subplot_mosaic(
    [
        ["Free", "MOP", "no", "Tend", "STD_single"],
        ["Energy_free", "Energy_MOP", "colorbar", "pdf", "pdf"],
    ],
    gridspec_kw=dict(hspace=0.5,
                     wspace=0.8,
                     width_ratios=[2, 2, 0.1, 0.6, 0.6]),
)
                        

# --- Free network: sample unit activities (top) and energy trace (bottom) ---
n_steps_free = params['TotalT'] / params['dt']
time_free = torch.arange(0, n_steps_free)

ax_free_act = axes['Free']
ax_free_act.plot(time_free, Activity_free[traj, :20, :].T, alpha=0.9)
ax_free_act.set_xlim(0, n_steps_free)
ax_free_act.set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
ax_free_act.set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)
ax_free_act.set_xlabel('t', fontsize=12)
ax_free_act.set_ylabel('x(t)', fontsize=12)
ax_free_act.set_ylim(-1, 1)

# Energy: L2 norm over units of (x + 1) at each time step, divided by N
Energy_free = np.linalg.norm(Activity_free[traj, :, :] + 1, axis=0) / params['N']

ax_free_en = axes['Energy_free']
ax_free_en.plot(time_free, Energy_free, color=pal_grey[3], linewidth=1.5)
# Dashed reference line at the threshold rnn.x_th_plus
ax_free_en.plot(time_free, rnn.x_th_plus * torch.ones(rnn.t),
                linestyle='--', linewidth=1.5, color='#98aaac')
ax_free_en.set_xlim(0, n_steps_free)
ax_free_en.set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
ax_free_en.set_xlabel('t', fontsize=12)
ax_free_en.set_ylabel('E(t)', fontsize=12)
ax_free_en.set_ylim([0.03, 0.15])
ax_free_en.set_yticks([0.03, 0.11, 0.15], ['0.03', '0.11', '0.15'], fontsize=12)

for _ax in (ax_free_en, ax_free_act):
    _ax.tick_params(axis='both', which='both', labelsize=12)


# --- MOP network: sample unit activities (top) and entropy-coloured energy (bottom) ---
from matplotlib.collections import LineCollection

n_steps_mop = params['TotalT'] / params['dt']
time_mop = torch.arange(0, n_steps_mop)

axes['MOP'].plot(time_mop, Activity_MOP[traj, :15, :].T, alpha=0.9)
axes['MOP'].set_xlim(0, n_steps_mop)
axes['MOP'].set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
axes['MOP'].set_yticks([-1, 0, 1], ['-1', '0', '1'], fontsize=12)
axes['MOP'].set_xlabel('t', fontsize=12)
axes['MOP'].set_ylabel('x(t)', fontsize=12)
axes['MOP'].set_ylim(-1, 1)

# Energy: L2 norm over units of (x + 1) at each time step, divided by N
Energy_MOP = np.linalg.norm(Activity_MOP[traj, :, :] + 1, axis=0) / params['N']

# Dashed reference line at the threshold rnn.x_th_plus
axes['Energy_MOP'].plot(time_mop, rnn.x_th_plus * torch.ones(rnn.t),
                        linestyle='--', linewidth=1.5, color='#98aaac')
axes['Energy_MOP'].set_xlim(0, n_steps_mop)
axes['Energy_MOP'].set_xticks([0, 500, 1000], ['0', '500', '1000'], fontsize=12)
axes['Energy_MOP'].set_xlabel('t', fontsize=12)
axes['Energy_MOP'].set_ylabel('E(t)', fontsize=12)
axes['Energy_MOP'].set_ylim([0.03, 0.15])
axes['Energy_MOP'].set_yticks([0.03, 0.11, 0.15], ['0.03', '0.11', '0.15'], fontsize=12)

# Colour segment [i, i+1] of the energy trace by H_mop[traj, 0] at step i.
# A single LineCollection replaces one plot() call per segment, and the x
# positions follow the actual trace length instead of a hard-coded range(1000).
colors = magma(norm(H_mop[traj, 0])).reshape(-1, 4)
energy_points = np.column_stack([np.arange(len(Energy_MOP)), Energy_MOP])
energy_segments = np.stack([energy_points[:-1], energy_points[1:]], axis=1)
axes['Energy_MOP'].add_collection(
    LineCollection(energy_segments, colors=colors[:len(energy_segments)], linewidths=1.5)
)

axes['MOP'].tick_params(axis='both', which='both', labelsize=12)
axes['Energy_MOP'].tick_params(axis='both', which='both', labelsize=12)

# --- t_end vs. epoch: mean over axis 0, shaded by std / sqrt(Naverage) ---
tend_mean = np.mean(T_end, 0)
tend_sem = np.std(T_end, 0) / np.sqrt(params['Naverage'])

ax_tend = axes['Tend']
ax_tend.plot(tend_mean, alpha=0.9, color=pal_MOP)
ax_tend.fill_between(
    x=np.arange(params['Nepochs']),
    y1=tend_mean - tend_sem,
    y2=tend_mean + tend_sem,
    alpha=0.2,
    color=pal_MOP,
)
ax_tend.set_xlabel('epochs', fontsize=12)
ax_tend.set_ylabel('$t_{end}$', fontsize=12)
ax_tend.set_xticks([0, 40], ['0', '40'], fontsize=12)
ax_tend.set_yticks([0, 1000], ['0', '1000'], fontsize=12)
ax_tend.tick_params(axis='both', which='both', labelsize=12)
# --- Activity histograms (time-averaged bin counts / N), free vs. MOP ---
ax_pdf = axes['pdf']
ax_pdf.tick_params(axis='both', which='both', labelsize=12)

bin_index = np.arange(N_bins)
for occ, edge_colour, curve_label in ((Occupancy_free, pal_grey[2], 'free'),
                                      (Occupancy_MOP, pal_MOP, 'MOP')):
    ax_pdf.barh(bin_index, np.mean(occ, axis=1) / rnn.N, height=1,
                facecolor='none', edgecolor=edge_colour, label=curve_label)

# Bins 0, 24 and 49 sit near activations -1, 0 and 1 respectively
ax_pdf.set_yticks([0, 24, 49], ['-1', '0', '1'], fontsize=12)
ax_pdf.set_xticks([])
ax_pdf.set_ylabel('x', fontsize=12)
ax_pdf.set_xlabel('pdf', fontsize=12)
ax_pdf.legend()

# --- <sigma> vs. epoch: mean over axis 0, shaded by std / sqrt(Naverage) ---
std_single_mean = np.mean(STD_single, 0)
std_single_sem = np.std(STD_single, 0) / np.sqrt(params['Naverage'])

ax_std = axes['STD_single']
ax_std.plot(std_single_mean, alpha=0.9, color=pal_MOP, label='MOP')
ax_std.fill_between(
    x=np.arange(params['Nepochs']),
    y1=std_single_mean - std_single_sem,
    y2=std_single_mean + std_single_sem,
    alpha=0.2,
    color=pal_MOP,
)
ax_std.set_xlabel('epochs', fontsize=12)
ax_std.set_xticks([0, 40], ['0', '40'], fontsize=12)
ax_std.set_ylabel(r'$\langle\sigma\rangle$', fontsize=12)
ax_std.tick_params(axis='both', which='both', labelsize=12)

# 'no' is an empty spacer panel: hide its ticks, tick labels and all four spines.
ax_spacer = axes['no']
ax_spacer.tick_params(axis='both', which='both', bottom=False, left=False,
                      labelbottom=False, labelleft=False)
for side in ('top', 'right', 'bottom', 'left'):
    ax_spacer.spines[side].set_visible(False)

# Colour bar for the magma entropy scale used on the MOP energy trace
sm = plt.cm.ScalarMappable(cmap=magma, norm=norm)
sm.set_array([])  # empty array for the scalar mappable
cbar = fig.colorbar(sm, cax=axes['colorbar'])
cbar.set_label(r'$\mathcal{H}(\mathcal{A}|x)$', fontsize=12)

# Tick the two ends of the colour scale, taken from `norm` so they track the
# data instead of being hard-coded. The '<' is written in math mode because a
# bare '<' in text mode is typeset as an inverted exclamation mark when
# text.usetex is enabled.
h_low, h_high = norm.vmin, norm.vmax
axes['colorbar'].set_yticks([h_low, h_high],
                            [rf'$<{h_low:.2f}$', f'{h_high:.2f}'], fontsize=12)
cbar.outline.set_edgecolor('none')

fig.set_size_inches(15, 5)  # figure size in inches
plt.show()
