import time
import numpy as np
import torch
from nn import MLP
from nn.cola_nn import colafy

# Micro-benchmark: time one forward + backward pass of a (structured) MLP.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(seed=21)

# Benchmark size. Larger setting kept for reference:
# batch_size = 1024
# repeat_n = 1_000
batch_size = 68
repeat_n = 10

# MLP shape. Wider setting kept for reference:
# dim_in, dim_out, depth, width = 32 * 32 * 3, 10, 9, 2 ** 13
dim_in, dim_out, depth, width = 32 * 32 * 3, 10, 9, 2 ** 8
model = MLP(dim_in, dim_out, depth, width)

# Structure handed to colafy (semantics are defined in nn.cola_nn, not here).
# Alternative settings kept for reference:
# struct, layers, rank_frac, tt_dim, tt_rank = "none", "all_but_last", 0.1, 2, 1
# struct, layers, rank_frac, tt_dim, tt_rank = "low_rank", "all_but_last", 0.1, 2, 1
struct, layers, rank_frac, tt_dim, tt_rank = "block_tt", "all_but_last", 0.1, 2, 1
colafy(model, struct=struct, layers=layers, rank_frac=rank_frac, tt_dim=tt_dim, tt_rank=tt_rank)
model = torch.compile(model)
x = torch.randn(batch_size, dim_in)
x = x.to(device)
model = model.to(device)

# CUDA kernels are launched asynchronously, so the host clock must only be
# read after the device has drained its queue; otherwise we would time kernel
# launches instead of the actual forward/backward computation.
is_cuda = device.type == "cuda"

times = np.zeros(repeat_n)
for idx in range(repeat_n):
    if is_cuda:
        # Do not charge previously queued work (e.g. the host-to-device copy)
        # to this iteration.
        torch.cuda.synchronize()
    tic = time.perf_counter()  # monotonic, high-resolution interval timer
    y = model(x)
    loss = torch.mean(y)
    loss.backward()
    if is_cuda:
        torch.cuda.synchronize()
    toc = time.perf_counter()
    times[idx] = toc - tic

# The first iteration is excluded from the statistics (presumably because it
# includes torch.compile's compilation / warm-up cost). When only a single
# iteration was run there is nothing else to report, so it is kept.
n_warmup = 1 if repeat_n > 1 else 0
timed = times[n_warmup:]
print(f"Struct: {struct}")
# Report the number of samples that actually entered the mean/total.
print(f"Mean: {np.mean(timed):1.3e} over {timed.size:,d} times | Total {np.sum(timed):1.3e}")
