import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# --- Global plot styling ----------------------------------------------------
# Colour list shared by both figures; methods pick an entry via ``pal_map``.
pal = [
    '#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
    '#ff7f00', '#ffff33', '#a65628',
]
# White-grid theme with enlarged fonts and thick lines, then install ``pal`` as
# the active colour cycle (set after the theme so the theme does not override it).
sns.set(
    style="whitegrid",
    font_scale=2.0,
    rc={"lines.linewidth": 3.0},
)
sns.set_palette(sns.color_palette(pal))

# --- Data for Figure 1 (lr_gen_mlp.pdf) -------------------------------------
N = 3  # NOTE(review): appears unused in this script.

# Widths swept on the x-axis: 4**2 .. 4**6 = 16 .. 4096.
width = np.array([4 ** k for k in range(2, 7)])

# Test accuracy (%) per method, stored as [ours, naive] so the list position
# lines up with ``cases`` / ``markers``; each array is aligned with ``width``.
# NOTE(review): Monarch/naive starts at 79 then drops to 70, and BTT/ours ends at
# 71 after 79 -- both break the otherwise +1-per-width pattern; confirm these
# are not transcription typos.
results = {
    method: [np.array(ours), np.array(naive)]
    for method, (ours, naive) in {
        "Monarch": ([75, 76, 77, 78, 79], [79, 70, 71, 72, 73]),
        "BTT": ([76, 77, 78, 79, 71], [77, 78, 79, 80, 81]),
        "LowR": ([68, 69, 70, 71, 72], [71, 72, 73, 74, 75]),
        "Kron": ([63, 64, 65, 66, 67], [67, 68, 69, 70, 71]),
    }.items()
}

# Legend suffix and marker for each position in a ``results`` list.
cases = "ours naive".split()
markers = list("Pv")

# Per-method index into ``pal`` and horizontal nudge used when plotting.
# "Dense" has no data in this script.
pal_map = dict(BTT=0, Dense=1, Monarch=2, LowR=3, Kron=4)
shift_map = dict(BTT=-5, Dense=0, Monarch=5, LowR=-7.5, Kron=7.5)

# --- Figure 1: accuracy vs. width -> lr_gen_mlp.pdf -------------------------
# Per-method horizontal offset, expressed in decades of the log-scaled x-axis
# per unit of ``shift_map``. Scaling ``width`` multiplicatively keeps the gap
# between methods visually constant at every width, so methods that share a
# width no longer stack on the same x position.
SHIFT_DECADES_PER_UNIT = 0.015

plt.figure(dpi=100, figsize=(8, 6))
for label, val in results.items():
    # Same offset for both cases of a method; only the marker differs.
    x = width * 10 ** (SHIFT_DECADES_PER_UNIT * shift_map[label])
    for idx, y in enumerate(val):
        # Colour identifies the method, marker identifies the case (ours/naive).
        plt.scatter(x, y, label=f"{label}({cases[idx]})", c=pal[pal_map[label]], marker=markers[idx])
plt.ylabel('Test Accuracy (%)')
plt.xlabel('Width')
plt.xscale('log')
# Series carry labels, but no legend is drawn for this figure.
plt.tight_layout()
plt.savefig("lr_gen_mlp.pdf")
plt.show()

# --- Data for Figure 2 (lr_gen_vit.pdf) -------------------------------------
# Rebinds the module-level ``results`` / ``cases`` / ``markers`` / ``shift_map``
# read by the second figure. Each method maps to [ours, naive] test accuracies
# (%), each array aligned with ``width``.
results = {
    "Kron": [np.array([62., 63., 66.5, 67, 68]), np.array([65., 67., 68, 69, 70])],
    "BTT": [np.array([74., 75., 77., 78, 79]), np.array([74.3, 77.3, 80., 81, 82])],
    "Monarch": [np.array([70., 72.5, 75.5, 76, 77]), np.array([74., 76.3, 77., 78, 79])],
}
# Legend suffixes, indexed like each ``results`` list; spelled the same as in
# the first figure so the two plots' labels agree.
cases = ["ours", "naive"]
markers = ["P", "v"]
# Same per-method horizontal nudges as the first figure.
shift_map = {"BTT": -5, "Dense": 0, "Monarch": 5, "LowR": -7.5, "Kron": 7.5}

# --- Figure 2: accuracy vs. width -> lr_gen_vit.pdf -------------------------
# Per-method horizontal offset, expressed in decades of the log-scaled x-axis
# per unit of ``shift_map``. An additive offset (``width + shift``) would spread
# the methods over roughly a third of a decade at width 16 but under a
# thousandth of a decade at width 4096; scaling multiplicatively keeps the
# spacing visually constant across the whole axis.
SHIFT_DECADES_PER_UNIT = 0.015

plt.figure(dpi=100, figsize=(8, 6))
for label, val in results.items():
    # The offset depends only on the method, so compute it once per method.
    x = width * 10 ** (SHIFT_DECADES_PER_UNIT * shift_map[label])
    for idx, y in enumerate(val):
        # Colour identifies the method, marker identifies the case (ours/naive).
        plt.scatter(x, y, label=f"{label}({cases[idx]})", c=pal[pal_map[label]], marker=markers[idx])
plt.ylabel('Test Accuracy (%)')
plt.xlabel('Width')
plt.xscale('log')
# plt.legend()  # legend intentionally left off; uncomment to label the series
plt.tight_layout()
plt.savefig("lr_gen_vit.pdf")
plt.show()
