from matplotlib import pyplot as plt
import seaborn as sns
from experiments.fns import get_project_data
from experiments.fns import get_baselines
from experiments.fns import rename_row
from experiments.fns import label_row

# Load the last logged step of every run in the "lr_vecs" project and keep
# only runs whose "epoch" exceeds 200.
df = get_project_data(project="lr_vecs", steps=[-1])
# .copy() detaches the filtered frame from the unfiltered one, so the column
# assignments below modify an owned frame rather than a possible view
# (otherwise pandas emits SettingWithCopyWarning).
df = df[df["epoch"] > 200].copy()
# Derive a display name and a plot label for each run from its row.
df["vec"] = df.apply(rename_row, axis=1)
df["label"] = df.apply(label_row, axis=1)
target_var = "train_loss_avg"
# Show every available vector name so the selection list below can be checked.
print(df["vec"].unique())

# Mixing vectors to keep. Uncomment entries to include more configurations.
exprs = [
    "0.3-0.0-0.0-0.7-0.0-0.3-0.7",
    "0.7-0.0-0.0-0.3-0.0-0.7-0.3",
    # "0.6-0.0-0.0-0.4-0.0-0.6-0.4",
    # "0.67-0.0-0.0-0.33-0.0-0.67-0.33",
    "0.33-0.0-0.0-0.67-0.0-0.33-0.67",
    "0.4-0.0-0.0-0.6-0.0-0.4-0.6",
    # "0.1-0.0-0.0-0.9-0.0-0.1-0.9",
    # "0.0-0.4-0.6-0.0-0.6-0.0-0.4",
    # "0.0-0.25-0.75-0.0-0.25-0.0-0.75",
    # "0.0-0.9-0.1-0.0-0.9-0.0-0.1",
    # "0.0-0.1-0.9-0.0-0.1-0.0-0.9",
    # "0.0-0.2-0.8-0.0-0.2-0.0-0.8",
    # "0.0-0.3-0.7-0.0-0.3-0.0-0.7",
]
# The "vec" values to match are the vector string plus this fixed suffix.
exprs = [exp + " (BMM0) Adam" for exp in exprs]
df = df[df["vec"].isin(exprs)]

# Learning-rate sweep: "train_loss" against "lr" on a log x-axis. Lines are
# colored by width and styled by mixing vector, with a marker on each point.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")
plt.figure(dpi=100, figsize=(20, 10))

sweep_kw = dict(x="lr", y="train_loss", data=df, style="vec", hue="width")
sns.lineplot(**sweep_kw)
sns.scatterplot(**sweep_kw, s=200)

plt.ylabel("Train Loss")
plt.xlabel("lr")
plt.xscale("log")
plt.ylim(1.70, 2.5)
# Put the legend outside the axes, to the right.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()

# For each (vec, width) pair, keep only the row with the lowest "error".
# NOTE(review): runs are selected by "error", while the FLOPs plot uses
# target_var on the y-axis. Confirm this choice of metric is intended.
best_rows = df.groupby(["vec", "width"])["error"].idxmin()
df = df.loc[best_rows]

# Baseline runs from the "lr_baselines" project, restricted to "none" and
# two reference mixing vectors (each with the same suffix as the main runs).
dfb = get_baselines(project="lr_baselines", target_var=target_var)
baseline_vecs = (
    "none",
    "0.5-0.5-0.0-0.0-0.5-0.5-0.0",
    "0.5-0.0-0.0-0.5-0.0-0.5-0.5",
)
exprs = [f"{vec} (BMM0) Adam" for vec in baseline_vecs]
dfb = dfb[dfb["vec"].isin(exprs)]

# Compute axis for the comparison plot; switch to "cola_params" to plot
# against parameter count instead.
x_var = "cola_flops"
# x_var = "cola_params"

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
plt.figure(dpi=75, figsize=(25, 15))
sns.set_palette("Set2")

# Baselines first, then the selected runs. Each frame gets markers plus a
# connecting line, styled by its "label" column.
for frame in (dfb, df):
    sns.scatterplot(x=x_var, y=target_var, data=frame, style="label", s=200)
    sns.lineplot(x=x_var, y=target_var, data=frame, style="label")

plt.ylim([1.7, 2.2])
y_label = "Train Loss" if target_var.startswith("train_loss") else "Error"
x_label = "FLOPs" if x_var == "cola_flops" else "Params"
plt.ylabel(y_label)
plt.xlabel(x_label)
plt.xscale("log")
plt.yscale("log")
# Put the legend outside the axes, to the right.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
