import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from trainkit.saving import load_object

# Plot benchmark timings (mean with standard-error bars) against width, one
# line per structure, for the device recorded in ./logs/bench.pkl, and save
# the figure to ./logs/time.png.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})

data = load_object("./logs/bench.pkl")
# NOTE(review): assumes `data` is a sequence of rows in exactly this column
# order — confirm against the script that writes bench.pkl.
df = pd.DataFrame(data, columns=["struct", "width", "device", "mean", "sterr"])

# The title names a single device, so restrict the plot to that device's rows;
# otherwise timings from different devices would be joined into one line.
device = df["device"].unique()[0]
df = df[df["device"] == device]
struct_s = list(df["struct"].unique())

# Fixed colours for up to three structures. With more, fall back to an evenly
# spaced palette: zip() would otherwise silently drop the extra structures
# from the dict and `colors[struct]` below would raise KeyError.
base_colors = ["#7570b3", "#1b9e77", "#8c510a"]
if len(struct_s) > len(base_colors):
    base_colors = sns.color_palette("husl", len(struct_s))
colors = {stru: col for stru, col in zip(struct_s, base_colors)}

plt.figure(dpi=100, figsize=(10, 8))
plt.title(f"Device {device}")
for struct in struct_s:
    # Sort by width so the connecting line runs left to right even if the
    # benchmark rows were recorded out of order.
    dff = df[df["struct"] == struct].sort_values("width")
    # errorbar draws the connecting line itself; a separate plt.plot call
    # would only overdraw the same line.
    plt.errorbar(
        dff["width"],
        dff["mean"],
        yerr=dff["sterr"],
        label=struct,
        color=colors[struct],
    )
plt.xlabel("Width")
plt.ylabel("Time")
plt.ylim([df["mean"].min() * 0.8, df["mean"].max() * 1.2])
plt.xscale("log")
plt.yscale("log")
# Create the legend before tight_layout so the layout pass sees every artist.
plt.legend()
plt.tight_layout()
plt.savefig("./logs/time.png")
plt.show()
