from matplotlib import pyplot as plt
import seaborn as sns
from experiments.fns import get_project_data
from experiments.fns import get_baselines
from experiments.fns import concat_and_format

# Runs of the "inits" project that finished and logged a training loss.
df = get_project_data(project="inits")
df = df.loc[df["state"] == "finished"].dropna(subset=["train_loss"])


def rename(row):
    """Return the display label used as the "vec" column for one run.

    ``row`` is a single record (a DataFrame row when called through
    ``df.apply(rename, axis=1)``) carrying a ``"struct"`` field:

    * ``"ein_expr"`` / ``"ein_expr_bmm"``: the ``"expr0"`` field followed by
      " (RSGD)" or " (BMM)" respectively;
    * ``"einsum_bmm"``: ``concat_and_format(row)`` followed by " (BMM)";
    * ``"einsum"``: ``concat_and_format(row)`` unchanged;
    * anything else: the ``"struct"`` value itself.
    """
    struct = row["struct"]
    if struct in ("ein_expr", "ein_expr_bmm"):
        suffix = " (RSGD)" if struct == "ein_expr" else " (BMM)"
        return row["expr0"] + suffix
    if struct in ("einsum", "einsum_bmm"):
        label = concat_and_format(row)
        return label + " (BMM)" if struct == "einsum_bmm" else label
    return struct


# Human-readable label per run (see rename). result_type="reduce" keeps the
# result a Series even when df is empty; otherwise pandas returns an empty
# DataFrame there and assigning it to the single "vec" column fails.
df["vec"] = df.apply(rename, axis=1, result_type="reduce")
# Labels to compare: one expression under the two struct variants that
# rename tags as (BMM) and (RSGD).
# exprs = ["0.74-0.04-0.22-0.0-0.71-0.27-0.02 (BMM)", "0.74-0.04-0.22-0.0-0.71-0.27-0.02 (RSGD)"]
exprs = ["0.38-0.29-0.33-0.0-0.04-0.35-0.61 (BMM)", "0.38-0.29-0.33-0.0-0.04-0.35-0.61 (RSGD)"]
# exprs = ["eg,bdg,adeg->abd (BMM)", "eg,bdg,adeg->abd (RSGD)"]
# exprs = ["fg,bdfg,dg->bd (BMM)", "fg,bdfg,dg->bd (RSGD)"]
# exprs = ["fg,dfg,adg->ad (BMM)", "fg,dfg,adg->ad (RSGD)"]
# exprs = ["e,bd,ade->abd (BMM)", "e,bd,ade->abd (RSGD)"]
df = df[df["vec"].isin(exprs)]

# Train loss vs. learning rate, colored by width, styled by label.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")
plt.figure(dpi=100, figsize=(20, 10))
sns.lineplot(x="lr", y="train_loss", data=df, style="vec", hue="width")
# The line plot already supplies legend entries for these hue/style levels;
# plt.legend() below gathers labelled artists from every seaborn call, so a
# second legend from the scatter would list each entry twice.
sns.scatterplot(x="lr", y="train_loss", data=df, style="vec", hue="width", s=200, legend=False)
plt.ylabel("Train Loss")
plt.xlabel("lr")
plt.xscale("log")
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()

# For every (vec, width) pair keep only the lowest-train-loss run, i.e. the
# best over the remaining swept settings (such as lr).
df = df.loc[df.groupby(["vec", "width"])["train_loss"].idxmin()]

# Baseline runs to compare against, restricted to the labels listed here.
dfb = get_baselines()
exprs = ["fg,dfg,adg->ad (BMM)", "none", "btt"]
# exprs = ["btt", "none"]
dfb = dfb[dfb["vec"].isin(exprs)]

# Finished plain-"einsum" runs from the cifar_vecs project. .copy() makes the
# filtered frame independent, so adding the "vec" column writes to it
# directly rather than to a slice of the raw project data (the chained
# assignment pandas flags with SettingWithCopyWarning).
df2 = get_project_data(project="cifar_vecs")
df2 = df2[(df2["state"] == "finished") & (df2["struct"] == "einsum")].copy()
df2["vec"] = df2.apply(concat_and_format, axis=1)
df2 = df2.dropna(subset=["train_loss"])
df2 = df2[df2["vec"].isin(["0.38-0.29-0.33-0.0-0.04-0.35-0.61"])]
# df2 = df2[df2["vec"].isin(["0.74-0.04-0.22-0.0-0.71-0.27-0.02"])]
df2 = df2.loc[df2.groupby(["vec", "width"])["train_loss"].idxmin()]

# Train loss vs. FLOPs for baselines, best sweep runs, and cifar_vecs runs.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
plt.figure(dpi=75, figsize=(25, 15))
sns.set_palette("Set2")
# One explicit color per data source. Without ``color`` each hue-less
# seaborn call takes its own default from the property cycle, so a series'
# markers and its connecting line are not guaranteed to share a color.
# NOTE: each call builds its own style mapping from its own ``vec`` levels,
# so dash/marker styles can repeat across sources; color tells them apart.
palette = sns.color_palette()
for color, data in zip(palette, (dfb, df, df2)):
    # Only the line plot contributes legend entries; plt.legend() collects
    # labelled artists from every call, so a scatter legend would duplicate them.
    sns.scatterplot(x="cola_flops", y="train_loss", data=data, style="vec", s=200, color=color, legend=False)
    sns.lineplot(x="cola_flops", y="train_loss", data=data, style="vec", color=color)
plt.ylabel("Train Loss")
plt.xlabel('FLOPs')
plt.xscale('log')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
