from pathlib import Path

import seaborn as sns
from matplotlib import pyplot as plt

from experiments.fns import get_project_data

# Axis variables for the plot. The projected column list below is derived
# from them, so changing either one keeps the selected columns consistent.
x_var = "epoch"
target_var = "train_loss"

# Request steps 0..99 of the "rsgd" project. only_finished=False presumably
# also returns runs that are still in progress (verify in get_project_data).
steps = list(range(100))
df = get_project_data(project="rsgd", steps=steps, only_finished=False)

# Keep only the width-4096, lr=30 runs, and only the columns used downstream.
columns = [x_var, target_var, "struct", "width", "depth", "lr", "cola_flops"]
df = df.loc[(df["width"] == 4096) & (df["lr"] == 30), columns]

# Large fonts/lines: the figure is meant to be embedded as a standalone PDF.
sns.set(style="whitegrid", font_scale=5.0, rc={"lines.linewidth": 3.0})
plt.figure(dpi=100, figsize=(20, 18))
colors = sns.color_palette("Set2")

# Map each run's "struct" value to its legend label; dict order fixes colors.
labels = {"rbtt": "BTT(RSGD)", "btt": "BTT(muP)"}
for idx, (key, label) in enumerate(labels.items()):
    # Sort so the connecting line follows x order even if rows arrive unsorted.
    dff = df[df["struct"] == key].sort_values(x_var)
    # Use color= rather than c=: a bare RGB tuple passed as c= is value-mapped
    # by matplotlib when its length happens to equal the number of points.
    plt.scatter(dff[x_var], dff[target_var], color=colors[idx], s=200, label=label)
    plt.plot(dff[x_var], dff[target_var], color=colors[idx], lw=3.0)

plt.ylim([1.9, 2.3])
plt.ylabel("Train Loss" if target_var.startswith("train_loss") else "Test Error")
plt.xlabel("Epochs")
plt.legend()

# tight_layout must run before savefig, otherwise the saved file keeps the
# un-adjusted layout (large fonts get clipped at the figure edges).
plt.tight_layout()
out_path = Path("logs/rsgd_mu_mlp_cifar10.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path)
plt.show()
