import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Weights & Biases CSV export to analyse. Earlier exports are kept below for
# reference; uncomment one of them to re-plot an older sweep.
# filepath = "./logs/wandb_export_2024-04-11T13_35_47.891-04_00.csv"
# filepath = "./logs/wandb_export_2024-04-12T13_48_07.817-04_00.csv"
filepath = "./logs/wandb_export_2024-04-12T13_55_58.887-04_00.csv"
df = pd.read_csv(filepath)

# Only keep runs whose "State" column reads "finished".
df = df[df["State"] == "finished"]

# Rows with struct == "none" have their "vec" label set to the struct name.
# ``.copy()`` makes df2 an independent frame, so adding the column does not
# write into a slice of ``df`` (the pattern behind SettingWithCopyWarning).
df2 = df[df["struct"] == "none"].copy()
df2["vec"] = df2["struct"]

# Rows with struct == "einsum" get a "vec" column added later in the script,
# so they are copied out of the filtered frame for the same reason.
df = df[df["struct"] == "einsum"].copy()


def concant_and_format(row):
    """Return the dash-separated label built from ``row``'s ``v0`` .. ``v6``.

    Each of the seven entries is looked up on ``row`` by key, converted with
    ``str`` and the pieces are joined with ``"-"``.
    """
    columns = (f"v{idx}" for idx in range(7))
    return "-".join(str(row[name]) for name in columns)


# Label every remaining row with its dash-joined v0..v6 values. ``assign``
# returns a new frame, so the column is never written into a filtered slice.
df = df.assign(vec=df.apply(concant_and_format, axis=1))

# Append the rows held in df2 and derive test accuracy per FLOP.
df = pd.concat((df, df2))
df["ratio"] = df["test_acc"] / df["cola_flops"]
df = df.sort_values(by="train_loss", ascending=True)
# df = df.sort_values(by="ratio", ascending=False)
selected_cols = ["epoch", "test_acc", "train_acc", "train_loss", "vec", "lr", "width", "cola_flops", "ratio"]
df = df[selected_cols]
print(f"There are: {len(df['vec'].unique()):,d} vecs")

# "vec" labels to keep for the plot. Every entry needs its own trailing comma:
# adjacent string literals are silently concatenated by Python, which would
# merge two labels into one that matches no row. Two labels that were listed
# twice are kept once, since ``isin`` ignores duplicates.
exprs = [
    "0.0-0.5-0.5-0.0-0.5-0.0-0.5",
    "0.5-0.5-0.0-0.0-0.5-0.5-0.0",
    "0.44-0.25-0.32-0.0-0.35-0.3-0.35",
    "0.2-0.41-0.4-0.0-0.15-0.39-0.45",
    "0.38-0.29-0.33-0.0-0.04-0.35-0.61",
    "0.08-0.81-0.11-0.0-0.12-0.57-0.31",
    # "0.13-0.35-0.51-0.0-0.69-0.14-0.17",
    "0.11-0.57-0.32-0.0-0.71-0.29-0.0",
    "0.57-0.09-0.34-0.0-0.07-0.65-0.28",
    "0.07-0.41-0.51-0.0-0.63-0.17-0.2",
    "0.0-0.64-0.36-0.0-0.04-0.49-0.47",
    "0.02-0.66-0.33-0.0-0.12-0.61-0.27",
    "0.16-0.7-0.14-0.0-0.11-0.61-0.28",
    "0.07-0.67-0.27-0.0-0.2-0.61-0.19",
    "none",
    # "btt",
]
df = df[df["vec"].isin(exprs)]

# For every (vec, width) pair keep only the row with the highest test_acc.
df = df.loc[df.groupby(["vec", "width"])["test_acc"].idxmax()]
# df = df[df["lr"] == 1.0]
# df = df.sort_values(by="train_loss", ascending=True)
# df = df.sort_values(by="ratio", ascending=False)

# Seaborn theme must be configured before the figure is created.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")

# One scatter point per selected row: marker style encodes "vec", colour
# encodes "width".
fig, ax = plt.subplots(dpi=75, figsize=(20, 10))
sns.scatterplot(
    data=df,
    x="cola_flops",
    y="test_acc",
    hue="width",
    style="vec",
    s=200,
    ax=ax,
)
# Alternative view: pass y="train_loss" above and label the axis "Train Loss".
ax.set_ylabel("Test Acc")
ax.set_xlabel('FLOPs')
ax.set_xscale('log')

# Anchor the legend's upper-left corner at the axes' upper-right corner so it
# sits outside the plotting area.
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
fig.tight_layout()
plt.show()
