import time
import math
import torch
import numpy as np
import cola
from trainkit.saving import save_object
from trainkit.timing import print_time_taken

# ---- Benchmark configuration -------------------------------------------
# Operator sizes to benchmark, largest first.
Ns = [10_000, 5_000, 1_000, 500, 100, 50, 10]
# Whether (and where) to pickle the collected timing records at the end.
save_output = True
output_path = "./logs/timings.pkl"

# Fix the RNG so every run draws the same random operators.
torch.manual_seed(seed=21)

# Prefer the first GPU when one is visible; assign "cpu" here to force a CPU run.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# One dict per (operator, size) measurement; the benchmark loop appends to it.
results: list = []
def _sync_device(dev):
    """Block until all queued work on ``dev`` has finished; no-op off CUDA.

    CUDA kernels are launched asynchronously, so without a synchronize a
    wall-clock measurement only captures the launch, not the matmul itself.
    """
    dev = torch.device(dev)
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


def _time_matmul(op, rhs, repeat_n, warmup=1):
    """Time ``op @ rhs`` ``repeat_n`` times.

    Args:
        op: operator applied on the left (a cola operator or a tensor).
        rhs: right-hand-side tensor; its ``.device`` is the one synchronized.
        repeat_n: number of timed repetitions (should be >= 1, otherwise the
            returned product is ``None``).
        warmup: untimed products run first, so one-off costs such as
            first-call library initialisation do not pollute sample 0.

    Returns:
        ``(times, product)``: a float64 NumPy array of per-repeat seconds and
        the product from the last timed repetition.
    """
    for _ in range(warmup):
        op @ rhs
    _sync_device(rhs.device)  # device is idle before the first tic

    times = np.zeros(shape=(repeat_n,))
    product = None
    for idx in range(repeat_n):
        tic = time.perf_counter()
        product = op @ rhs
        _sync_device(rhs.device)  # wait for the kernel before reading the clock
        times[idx] = time.perf_counter() - tic
    return times, product


repeat_n = 5  # timed repetitions per operator and size
t0 = time.time()
for N in Ns:
    print(f"Going over: {N:,d}")
    Nsq = math.isqrt(N)  # exact integer sqrt (no float rounding)
    x = torch.ones((N, 784 * 5)).to(device)

    # Low-rank operator L @ U with rank r = 20% of N.
    # NOTE(review): presumably cola keeps L @ U as a lazy product, so this
    # times L @ (U @ x) rather than a materialized N x N matrix — confirm.
    r = int(N * 0.2)
    L = cola.ops.Dense(torch.randn(size=(N, r)).to(device))
    U = cola.ops.Dense(torch.randn(size=(r, N)).to(device))
    LU = L @ U
    times, aux = _time_matmul(LU, x, repeat_n)
    results.append(
        {"time": times, "name": "lowr", "size": aux.shape[0], "device": LU.device.type}
    )
    del L, U, LU  # L and U own the storage; deleting LU alone frees nothing

    # Dense N x N baseline.
    D = cola.ops.Dense(torch.randn(size=(N, N)).to(device))
    times, aux = _time_matmul(D, x, repeat_n)
    results.append(
        {"time": times, "name": "dense", "size": aux.shape[0], "device": D.device.type}
    )
    del D

    # Kronecker product of two Nsq x Nsq factors (Nsq**2 <= N rows), applied
    # to the matching leading rows of x.
    K1 = torch.randn(size=(Nsq, Nsq)).to(device)
    K2 = torch.randn(size=(Nsq, Nsq)).to(device)
    K = cola.ops.Kronecker(K1, K2)
    times, aux = _time_matmul(K, x[:K.shape[0]], repeat_n)
    results.append(
        {"time": times, "name": "kron", "size": aux.shape[0], "device": K.device.type}
    )
    del K, K1, K2

t1 = time.time()
print_time_taken(t1 - t0)

if save_output:
    save_object(results, output_path)
