from pathlib import Path

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter

# Run-log columns plotted on the x and y axes, plus the y-axis label.
xname, yname, ylabel = 'cola_flops', 'test_error', 'Test Error'
# Dataset tag used in the figure title and the output filename.
ds = 'cifar10'

# Raw `struct` identifiers from the run log -> display names used in the plot.
# Several raw ids share one display name ('none'/'dense' -> 'Dense',
# 'btt'/'btt_norm' -> 'BTT').
struct_names = dict(
    none='Dense',
    dense='Dense',
    tt='TT',
    btt='BTT',
    btt_norm='BTT',
    kron='Kron',
    monarch='Monarch',
    low_rank='Low Rank',
    conv='Conv',
)

# Load the naive vs. structure-aware scaling sweep.
# (Alternative per-dataset input used previously: f"./logs/mlp_{ds}.csv".)
runs = pd.read_csv("./logs/aware_scale.csv")
# Keep only the structures of interest. `.copy()` detaches the filtered rows from
# the full frame so the column assignments below modify `runs` itself instead of
# a view/copy of a slice (avoids SettingWithCopyWarning / chained assignment).
runs = runs[runs['struct'].isin(['none', 'btt', 'kron', 'tt', 'monarch', 'low_rank', 'conv'])].copy()
# Raw ids -> display names; every id kept by the filter above is a key of struct_names.
runs['struct'] = runs['struct'].map(struct_names)
# Legend label for the multiplication mode: a truthy flag means the naive path.
# NOTE(review): presumably a boolean column; a missing value (NaN) is truthy and
# would be labelled 'Naive' — confirm the log has no gaps here.
runs['use_wrong_mult'] = runs['use_wrong_mult'].map(lambda flag: 'Naive' if flag else 'Struct-Aware')
# Drop every column the plot does not use.
runs = runs[[xname, yname, "struct", "use_wrong_mult"]]

# Global seaborn theme: white grid, enlarged fonts, thick lines. This affects all
# subsequent seaborn/matplotlib drawing in the script.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})

# Marker per multiplication mode: downward triangle for naive, circle for struct-aware.
marker_styles = dict(zip(['Naive', 'Struct-Aware'], ['v', 'o']))

# Colour/legend order of the structure display names.
# NOTE(review): 'Conv' rows are kept by the struct filter and mapped by
# struct_names, but 'Conv' is not listed here. Seaborn does not assign a visible
# hue colour to levels outside hue_order (recent versions draw them fully
# transparent, though they may still affect axis limits). Confirm this is intended.
hue_order = [
    'Dense',
    'BTT',
    'Monarch',
    'Low Rank',
    'Kron',
    'TT',
]
# One Set2 colour per entry of hue_order (second positional argument is n_colors).
pallette = sns.color_palette("Set2", len(hue_order))

# Scatter test error against FLOPs: colour encodes the structure, marker encodes
# naive vs. structure-aware multiplication.
plt.figure(dpi=100, figsize=(6, 6))
ax = sns.scatterplot(data=runs, x=xname, y=yname, hue='struct', hue_order=hue_order, s=200, palette=pallette,
                     style='use_wrong_mult', markers=marker_styles)
ax.set_ylabel(ylabel)
ax.set_xlabel('FLOPs')
ax.set_xscale('log')
ax.set_yscale('log')
# Plain-number tick labels (rather than the default log-axis 10^k style) on both
# major and minor y ticks, with a light grid on the minor ticks as well.
ax.yaxis.set_major_formatter(ScalarFormatter())
ax.yaxis.set_minor_formatter(ScalarFormatter())
ax.grid(which='minor', axis='y', linestyle='-', linewidth=0.5)

# Show only the marker entries (Naive / Struct-Aware) in the legend. Selecting by
# label instead of taking the last two entries does not depend on how seaborn
# orders or titles its legend entries, or on both style levels being present.
style_handles, style_labels = ax.get_legend_handles_labels()
style_entries = [(handle, label) for handle, label in zip(style_handles, style_labels)
                 if label in marker_styles]
ax.legend([handle for handle, _ in style_entries], [label for _, label in style_entries],
          loc='lower left', fontsize='small', handleheight=0.5, handlelength=1,
          handletextpad=0.02, borderaxespad=0.02)

ax.set_title(f'MLP on {ds.replace("cifar10", "CIFAR-10")}')
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
out_path = Path('./figures') / f'naive_mlp_{ds}_full.pdf'
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, bbox_inches='tight')
plt.show()
