import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Plot test accuracy against FLOPs (log scale) for the "ein_expr" runs in a
# W&B CSV export. Each distinct "expr0" value gets its own colour: individual
# runs are drawn as large points, and a line of the same colour connects them
# (by default seaborn's lineplot averages runs sharing a "cola_flops" value and
# shades an uncertainty band around that mean).

# Earlier export, kept for reference:
# filepath = "./logs/wandb_export_2024-03-29T13_04_03.983-04_00.csv"
filepath = "./logs/wandb_export_2024-03-29T14_25_27.391-04_00.csv"
df = pd.read_csv(filepath)
dff = df[df["struct"] == "ein_expr"]
if dff.empty:
    # Fail loudly instead of silently rendering an empty figure.
    raise ValueError(f'No rows with struct == "ein_expr" found in {filepath}')

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")
plt.figure(dpi=100, figsize=(20, 10))
sns.scatterplot(x="cola_flops", y="test_acc", data=dff, hue="expr0", s=200)
# The scatter plot already provides one legend entry per "expr0" value in the
# same colours; letting the lineplot add its own entries would make
# plt.legend() below list every value twice.
sns.lineplot(x="cola_flops", y="test_acc", data=dff, hue="expr0", legend=False)
plt.ylabel("Test Acc")
plt.xlabel('FLOPs')
plt.xscale('log')
# Place the legend outside the axes (to the right) so it does not hide points.
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
