import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Plot test accuracy against FLOPs for a selection of einsum expressions taken
# from a Weights & Biases CSV export.
#
# Earlier exports, kept for reference:
#   ./logs/wandb_export_2024-03-22T19_00_09.752-04_00.csv
#   ./logs/wandb_export_2024-03-27T11_11_45.987-04_00.csv
filepath = "./logs/wandb_export_2024-04-03T13_00_28.755-04_00.csv"

# Expressions shown in the scatter plot (the original list contained
# "e,bd,ade->abd" twice; membership filtering is unaffected by duplicates).
exprs = [
    "ef,bcf,ace->ab",
    "ef,bf,ae->ab",
    "fg,cdfg,acdg->ad",
    "fg,dfg,adg->ad",
    "eg,bdg,adeg->abd",
    "efg,dfg,adeg->ad",
    "e,bd,ade->abd",
]
# Alternative selections:
# exprs += ["fg,bfg,g->b", "eg,g,aeg->a", "f,bdf,d->bd", "g,bg,g->b", "e,d,ade->ad"]
# exprs += ["eg,bdg,adeg->abd", "fg,dfg,adg->ad", "eg,bg,eg->b", "fg,bdfg,adg->abd", "e,bd,ade->abd"]

# Expression whose individual runs are printed for inspection.
aux_expr = "f,bdf,d->bd"


def prepare_runs(df: pd.DataFrame) -> pd.DataFrame:
    """Return a prepared copy of *df*; the input frame is left untouched.

    The copy has the last three characters removed from every ``expr0``
    value and gains a ``ratio`` column equal to ``test_acc / cola_flops``.

    Working on a copy (instead of aliasing the raw frame) keeps the column
    assignments valid even when *df* is itself a filtered slice.
    """
    runs = df.copy()
    # NOTE(review): presumably strips a fixed three-character suffix from the
    # logged expression - confirm against the export format.
    runs["expr0"] = runs["expr0"].str[:-3]
    runs["ratio"] = runs["test_acc"] / runs["cola_flops"]
    return runs


def summarize_by_expr(runs: pd.DataFrame) -> pd.DataFrame:
    """Aggregate *runs* per ``expr0``: mean ``ratio`` and max ``test_acc``.

    Rows are ordered by ``test_acc`` (descending) with ``ratio`` (descending)
    as the tie-breaker. A single multi-key sort is used because two
    consecutive single-key sorts would discard the first ordering.
    """
    summary = runs.groupby(by="expr0").agg({"ratio": "mean", "test_acc": "max"})
    summary = summary.reset_index()
    return summary.sort_values(by=["test_acc", "ratio"], ascending=False)


def expr_details(
    runs: pd.DataFrame, expr: str, columns=("test_acc", "v1", "v5")
) -> pd.DataFrame:
    """Return the *columns* of every row in *runs* whose ``expr0`` equals *expr*."""
    return runs.loc[runs["expr0"] == expr, list(columns)]


if __name__ == "__main__":
    df = pd.read_csv(filepath)
    # Optional row filters that can be applied to ``df`` before preparation:
    # df = df[df["fact"] == "tree"]
    # df = df[df["v0"].isna()]
    # df = df[df["State"] != "failed"]
    runs = prepare_runs(df)

    # Per-expression summary; computed for inspection, not displayed below.
    summary = summarize_by_expr(runs)

    # Look the expression up BEFORE narrowing to ``exprs``: ``aux_expr`` is not
    # part of that list, so a lookup on the filtered frame is always empty.
    aux = expr_details(runs, aux_expr)
    print(aux)

    dff = runs[runs["expr0"].isin(exprs)]

    sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
    sns.set_palette("Set2")
    plt.figure(dpi=100, figsize=(20, 10))
    # Marker shape encodes the expression, colour encodes the width.
    sns.scatterplot(x="cola_flops", y="test_acc", data=dff, style="expr0", hue="width", s=200)
    plt.ylabel("Test Acc")
    plt.xlabel('FLOPs')
    plt.xscale('log')
    # Anchor the legend outside the axes so it does not cover the points.
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.show()
