import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from trainkit.saving import load_object

# Result files (.pkl) to load and plot together. To plot the CPU runs instead,
# swap in "./logs/mvm_dense_cpu.pkl" and "./logs/mvm_btt_cpu.pkl".
path_s = [
    "./logs/mvm_dense_cuda.pkl",
    "./logs/mvm_btt_cuda.pkl",
]

# Where the rendered figure is written.
output_path = "./logs/plot_mvm.png"

# Palette from which each distinct "struct" value is given a color when plotting.
colors = ["#b2182b", "#999999", "#ef8a62", "#4d4d4d", "#fddbc7"]

# Column labels applied to the records loaded from every file.
# NOTE(review): assumes each loaded object is row-like data with exactly these
# eight fields, in this order -- confirm against the script that writes the logs.
columns = ["struct", "dim_in", "dim_out", "batch_size", "flops", "device", "mean", "sterr"]

# Build one frame per input file and concatenate them in a single call, rather
# than re-copying the accumulated frame on every loop iteration. Row order and
# the per-file indices are the same as with pairwise concatenation.
df = pd.concat([pd.DataFrame(load_object(path), columns=columns) for path in path_s])

# Colors are consumed per distinct structure, not per input file, so validate
# the palette against the loaded data. An explicit raise is used because a bare
# `assert` disappears when Python runs with -O.
n_structs = len(df["struct"].unique())
if n_structs > len(colors):
    raise ValueError(
        f"missing colors: {n_structs} structures to plot but only {len(colors)} colors defined"
    )

# Plot mean time against FLOPs, one curve per structure, on log-log axes.
# The "sterr" column is not drawn.

# Distinct structures, in order of first appearance in the data.
struct_s = list(df["struct"].unique())
# Only the first device found is named in the title; any others in the data
# are not mentioned.
device = list(df["device"].unique())[0]
# Rebind `colors` from the palette list to a structure -> color lookup.
colors = dict(zip(struct_s, colors))

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
plt.figure(dpi=100, figsize=(10, 8))
plt.title(f"Device {device}")
for struct in struct_s:
    # Sort by FLOPs so the connecting line runs left to right; plt.plot joins
    # points in row order, and unsorted rows would make the curve double back.
    # A stable sort keeps the original order among rows with equal FLOPs.
    dff = df[df["struct"] == struct].sort_values("flops", kind="stable")
    plt.plot(dff["flops"], dff["mean"], color=colors[struct])
    # Only the scatter call carries the label, so each structure appears once
    # in the legend.
    plt.scatter(dff["flops"], dff["mean"], label=struct, color=colors[struct])
# Leave some headroom below the smallest and above the largest mean.
plt.ylim([np.min(df["mean"]) * 0.8, np.max(df["mean"]) * 1.2])
plt.xlabel("FLOPs")
plt.ylabel("Time (sec)")
plt.xscale("log")
plt.yscale("log")
plt.tight_layout()
plt.legend()
# Save before show() so the file is written regardless of how the window is closed.
plt.savefig(output_path)
plt.show()
