import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Global figure styling: large fonts and thick lines for presentation-sized
# plots. The palette is set separately so the scatter calls below cycle
# through "Set2" colors.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")

# Load a W&B run export and keep only rows whose "State" is "finished".
df = pd.read_csv("./logs/wandb_export_2023-12-20T19_42_49.514-05_00.csv")
mask = (df["State"] == "finished")
df = df[mask]

# Column used to split runs into separately colored series.
var_name = "width"
# Distinct values of `var_name`, ascending. NaN is dropped: a NaN value can
# never be selected with `==` below, so keeping it would only add an empty
# scatter with a "nan" legend entry.
var_s = list(np.sort(df[var_name].dropna().unique()))

plt.figure(dpi=100, figsize=(14, 8))
for var in var_s:
    dff = df[df[var_name] == var]
    plt.scatter(dff["cola_flops"], dff["test_acc"], label=var)
plt.xlabel("FLOPs")
plt.ylabel("Test Accuracy")
# Fixed y-range: runs with test_acc outside [50, 85] are not visible.
# NOTE(review): the range suggests test_acc is a percentage — confirm.
plt.ylim([50, 85])
# Log x-axis; non-positive FLOPs values cannot be shown on it.
plt.xscale("log")
# Legend is placed outside the axes, anchored at their top-right corner.
# Titling it with the column name makes the bare numeric labels readable.
plt.legend(title=var_name, loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
