import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Plot test accuracy against (log-scale) FLOPs for finished W&B runs,
# drawing one scatter series per kron_mult value.

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set1")

# W&B run export to plot.
CSV_PATH = "./logs/wandb_export_2023-11-16T21_43_28.846-05_00.csv"
# kron_mult values to compare; each becomes its own scatter series / legend entry.
KRON_MULTS = (1, 2, 8, 32)

df = pd.read_csv(CSV_PATH)
# Only keep runs whose W&B state is "finished".
df = df[df["State"] == "finished"]

plt.figure(dpi=100, figsize=(10, 8))
for case in KRON_MULTS:
    dff = df[df["kron_mult"] == case]
    plt.scatter(dff["cola_flops"], dff["test_acc"], label=f"{case}")
plt.xlabel("FLOPs")
plt.ylabel("Test Accuracy")
plt.xscale("log")
# Title the legend so the bare integer labels are identifiable as kron_mult values.
plt.legend(title="kron_mult")
plt.tight_layout()
plt.show()
