from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter
import pandas as pd
import seaborn as sns

# Scatter plot of test error vs. FLOPs for MLP runs on one dataset, comparing
# the dense baseline against the Kronecker structure with and without pixel
# shuffling.  Writes two PDFs: the plot itself and a standalone legend strip.

# Column of the log CSV plotted on the x axis / y axis, the y-axis caption,
# and the dataset key used in the input and output file names.
xname = 'cola_flops'
yname = 'test_error'
ylabel = 'Test Error'
ds = 'cifar10'

# Raw `struct` values from the log CSV -> display names used in the figure.
struct_names = {
    'none': 'Dense',
    'dense': 'Dense',
    'tt': 'TT',
    'btt': 'BTT',
    'btt_norm': 'BTT',
    'kron': 'Kron',
    'monarch': 'Monarch',
    'low_rank': 'Low Rank',
    'conv': 'Conv',
}

# Display names that actually appear in this figure (and in its legend).
no_shuffle_label = 'Kron w/o shuffle'
shown_structs = ['Dense', 'Kron', no_shuffle_label]

# ---------------------------------------------------------------------------
# Load and label the runs
# ---------------------------------------------------------------------------
runs = pd.read_csv(f"./logs/mlp_{ds}.csv")
# .copy() makes the filtered rows an independent frame, so the column
# assignments below do not write into a slice of the original DataFrame.
runs = runs[runs['struct'].isin(['none', 'kron'])].copy()
runs['struct'] = runs['struct'].map(struct_names)
# Kron runs whose `shuffle_pixels` flag is false get their own label.
# A value comparison is used instead of `is False` so the test does not depend
# on the cell being the Python `False` singleton (e.g. it also matches
# numpy.bool_); missing values compare unequal and keep the plain label.
no_shuffle = runs['shuffle_pixels'].eq(False) & runs['struct'].eq('Kron')
runs.loc[no_shuffle, 'struct'] = no_shuffle_label
runs = runs[[xname, yname, 'struct']]

# ---------------------------------------------------------------------------
# Main figure
# ---------------------------------------------------------------------------
sns.set(style="whitegrid", font_scale=2.5, rc={"lines.linewidth": 3.0})
# The full list of structures is kept (even though only three are drawn) so
# that each structure's palette colour is determined by its position here.
hue_order = ['Dense', 'BTT', 'Monarch', 'Low Rank', 'Kron', 'TT', 'Kron w/o shuffle']
palette = sns.color_palette("Set2", n_colors=len(hue_order))
fig = plt.figure(dpi=100, figsize=(10, 5))
ax = sns.scatterplot(data=runs, x=xname, y=yname, hue='struct', hue_order=hue_order, s=300, palette=palette)

# The legend is rendered separately below, so drop the in-axes one (if any).
inline_legend = ax.get_legend()
if inline_legend is not None:
    inline_legend.remove()

ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)
ax.set_ylabel(ylabel)
ax.set_xlabel('FLOPs')
ax.set_xscale('log')
ax.set_yscale('log')
# Plain numbers instead of powers of ten on the log-scaled y axis.
ax.yaxis.set_major_formatter(ScalarFormatter())
ax.yaxis.set_minor_formatter(ScalarFormatter())
# Grab the legend entries before the figure is shown.
handles, labels = ax.get_legend_handles_labels()
ax.set_title(f'MLP on {ds.replace("cifar10", "CIFAR-10")}')
fig.tight_layout()
fig.savefig(f'./figures/shuffle_mlp_{ds}.pdf', bbox_inches='tight')
plt.show()

# ---------------------------------------------------------------------------
# Standalone legend figure
# ---------------------------------------------------------------------------
legend_fig = plt.figure(figsize=(8, 1))
# Pick the entries by label rather than by position, so the selection does not
# depend on how many other entries the plotting call produced or their order.
legend_entries = [
    (handle, label) for handle, label in zip(handles, labels) if label in shown_structs
]
handles = [handle for handle, _ in legend_entries]
labels = [label for _, label in legend_entries]
ax_legend = legend_fig.add_subplot(111)
# One row holding every entry; max(1, ...) keeps ncol valid if nothing matched.
ax_legend.legend(handles, labels, loc='center', ncol=max(1, len(labels)))
ax_legend.axis('off')
legend_fig.tight_layout()
legend_fig.savefig(f'./figures/shuffle_mlp_{ds}_legend.pdf', bbox_inches='tight')
