#%%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl


# Global font sizes applied to every figure this script draws.
plt.rcParams.update({
    'axes.labelsize': 25,
    'axes.titlesize': 25,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 25,
})

# Per-algorithm (marker, colour) styling, keyed by the algorithm labels used
# in the 'Algorithm' level of the results index.
_algo_styles = (
    ("QRDQN", "s", "tab:orange"),
    ("p-DLTV", "D", "tab:green"),
    ("PQR", "X", "tab:red"),
    ("DLTV", "o", "tab:blue"),
)
markers = {algo: marker for algo, marker, _colour in _algo_styles}
palette = {algo: colour for algo, _marker, colour in _algo_styles}

# Order of the hue levels handed to seaborn (controls legend ordering).
hue_order = ['QRDQN', 'DLTV', 'p-DLTV', 'PQR']

# NOTE(review): not referenced by the plotting code in this script.
DIV_LINE_WIDTH = 50


# Load the raw sensitivity results; read_excel already returns a DataFrame,
# so no extra pd.DataFrame(...) wrapping is needed.
data = pd.read_excel('plot_sensitivity.xlsx', engine='openpyxl')

# Patch the 'Timestep' cell of row label 2 to 5000 (presumably correcting a
# value in the spreadsheet -- confirm). A single .loc assignment is used
# because the chained form data['Timestep'][2] = ... writes into an
# intermediate Series and does not modify `data` under pandas Copy-on-Write.
data.loc[2, 'Timestep'] = 5000

# Keep only rows with a positive hyperparameter value.
data = data[data['hyperparameter'] > 0]

# Keep the contiguous run of result columns 'QRDQN' .. 'PQR.16'
# (the '.N' suffixes look like pandas' de-duplicated repeated headers).
data = data.loc[:, 'QRDQN':'PQR.16']
print(data)

# Row index of the reshaped results: hyperparameter (outermost) x timestep x
# algorithm (innermost).
# NOTE(review): assumes the spreadsheet values, read row-major, are laid out
# in exactly this order -- confirm against the sheet.
mul_index = pd.MultiIndex.from_product(
    [
        [1, 5, 10, 50, 100, 500, 1000, 5000],
        ['5K', '10K', '15K', '20K'],
        ['QRDQN', 'DLTV', 'p-DLTV', 'PQR'],
    ],
    names=['Hyperparameter', 'Timesteps', 'Algorithm'],
)

# Each 17-value record holds: per-seed means (5), per-seed stds (5),
# per-seed errors (5), then the error mean and error std. Only the per-seed
# errors (positions 10-14) are kept, with columns labelled by seed number.
_records = data.to_numpy().reshape(-1, 17)
df_new = pd.DataFrame(
    _records[:, 10:15],
    index=mul_index,
    columns=pd.Index([1, 2, 3, 4, 5], name='Seed'),
)

pd.set_option('display.max_rows', None)

# Distinct hyperparameter values; the plotting code draws one panel per value.
hyp = df_new.index.levels[0]


import seaborn as sns


# One panel per hyperparameter value; panels share the y-axis so error levels
# are directly comparable. squeeze=False keeps `axes` two-dimensional, so the
# indexing below also works when `hyp` holds a single value (the default
# squeeze=True would return a bare Axes in that case).
fig, axes = plt.subplots(1, len(hyp), figsize=(50, 5), sharey=True, squeeze=False)

for i, c in enumerate(hyp):
    ax = axes[0, i]

    # Long format for seaborn: one row per (Timesteps, Algorithm, Seed) with
    # the per-seed error value in an 'Error' column.
    df_new_hyp = df_new.loc[c].stack().reset_index()
    df_new_hyp.rename(columns={0: 'Error'}, inplace=True)

    # Only the first panel gets a legend; the others pass legend=False so
    # no legend is created just to be blanked out again.
    sns.lineplot(
        data=df_new_hyp, x='Timesteps', y='Error', hue='Algorithm',
        style='Algorithm', err_style="band", errorbar=('ci', 95), ax=ax,
        markers=markers, palette=palette, hue_order=hue_order,
        linewidth=2.5, dashes=False, markersize=15, markeredgewidth=0,
        legend="auto" if i == 0 else False,
    )

    ax.set_title('Hyperparameter = ' + str(c))
    ax.grid(color='grey', linestyle=':', linewidth=1)
    if i == 0:
        # Rebuild the legend from seaborn's handles/labels; this replaces
        # seaborn's own legend (any title it set is not carried over).
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles=handles, labels=labels)

plt.subplots_adjust(wspace=0.1)






# %%
