import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Synthetic "test error vs. compute" curves for the MLP / CIFAR-10 figure.
# Every series follows err = 1e4 / flops**p; some structures are only
# evaluated at the first three compute points.
np.random.seed(21)  # NOTE(review): nothing in this file draws random numbers; seed looks unused
N = 5
mult = 100
flops = mult * np.arange(1, N + 1) ** 2.

# (legend label, exponent p, number of leading compute points used)
_MLP_SERIES = (
    ("BTT", 2.3, N),
    ("Dense", 2.0, N),
    ("Kron", 0.5, 3),
    ("LowR", 0.75, 3),
    ("Monarch", 1.9, N),
    ("TT", 1.5, 3),
    ("Banded", 1.7, 3),
)
cases = [(1e4 / (flops[:n_pts] ** p), label) for label, p, n_pts in _MLP_SERIES]
# Keep the individual series available under their historical names.
s1, s2, s3, s4, s5, s6, s7 = (err for err, _ in cases)

# Global styling shared by all three figures below: white grid, large fonts,
# thick lines, and a fixed 7-colour palette (one colour per MLP series).
# `sns.set_theme` is the supported spelling; `sns.set` is a deprecated alias.
sns.set_theme(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
# NOTE(review): '#ffff33' (yellow) has poor contrast on a white background.
pal = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628']
sns.set_palette(sns.color_palette(pal))

# Figure 1: MLP / CIFAR-10 test error vs. compute for every series in `cases`,
# on log-log axes, saved to perf_mlp.pdf.
plt.figure(dpi=100, figsize=(8, 6))
plt.title('MLP on CIFAR-10')
for sx, label in cases:
    xs = flops[:len(sx)]  # some series cover only the first few compute points
    # Draw the line first and reuse its colour for the markers, instead of
    # relying on plot() and scatter() independently picking matching defaults.
    (line,) = plt.plot(xs, sx)
    plt.scatter(xs, sx, color=line.get_color(), label=label)
plt.xlabel('TFLOPs')
plt.ylabel('Test Error')
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.tight_layout()
plt.savefig("perf_mlp.pdf")
plt.show()

# Figure 2: ViT / ImageNet.  Each series is a power law
# err = exp(beta + alpha * log(flops)) = exp(beta) * flops**alpha,
# given as (beta, alpha): intercept and slope in log-log space.
flops = (np.arange(0, N) + 1)**8.
cases = [((1.2, -0.7), "BTT"), ((1.2, -0.6), "Dense"), ((1.4, -0.5), "Monarch"), ((2.3, -0.1), "Kron")]
plt.figure(dpi=100, figsize=(8, 6))
plt.title('ViT on ImageNet')
for (beta, alpha), label in cases:
    n_keep = 3 if label == "Kron" else N  # Kron is shown at the first 3 points only
    sx = np.exp(beta + alpha * np.log(flops))[:n_keep]
    xs = flops[:len(sx)]
    # Reuse the line's colour for the markers so each series stays one colour.
    (line,) = plt.plot(xs, sx)
    plt.scatter(xs, sx, color=line.get_color(), label=label)
plt.xlabel('TFLOPs')
plt.ylabel('Test Error')
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.tight_layout()
plt.savefig("perf_vit.pdf")
plt.show()

# Figure 3: GPT / OpenWebText test perplexity.  Same (beta, alpha) power-law
# parameterization as the ViT figure; every series here spans all N points
# (there is no Kron series, so no truncation is needed).
flops = (np.arange(0, N) + 1)**10.
cases = [((1.2, -0.7), "BTT"), ((1.2, -0.6), "Dense"), ((1.4, -0.5), "Monarch")]
plt.figure(dpi=100, figsize=(8, 6))
plt.title('GPT on OpenWebText')
for (beta, alpha), label in cases:
    sx = np.exp(beta + alpha * np.log(flops))
    # Reuse the line's colour for the markers so each series stays one colour.
    (line,) = plt.plot(flops, sx)
    plt.scatter(flops, sx, color=line.get_color(), label=label)
plt.xlabel('TFLOPs')
plt.ylabel('Test PPL')
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.tight_layout()
plt.savefig("perf_gpt.pdf")
plt.show()
