import torch
import pandas as pd
import numpy as np
from nn import cola_nn
import gc


def compute_elapsed_gpu(fn):
    """Time one call of ``fn`` using a pair of CUDA timing events.

    Both events are recorded on the current CUDA stream around the call, and the
    device is synchronized before reading the interval.

    Args:
        fn: Zero-argument callable to time; its return value is discarded.

    Returns:
        Elapsed time in seconds (``Event.elapsed_time`` reports milliseconds).
    """
    begin_evt, finish_evt = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    begin_evt.record()
    fn()
    finish_evt.record()
    # The interval can only be read once both events have completed on the device.
    torch.cuda.synchronize()
    milliseconds = begin_evt.elapsed_time(finish_evt)
    return milliseconds / 1000


def profile(model_fn, input_sizes, batch_size, device='cuda', trials=10, **hypers):
    """Benchmark forward passes of ``model_fn``-built models over several widths.

    For each entry of ``input_sizes``, a model is built with
    ``model_fn(input_size, **hypers)`` and moved to ``device`` in float32. It is
    then fed a random ``(batch_size, input_size)`` input. After one warm-up call,
    the forward pass is timed ``trials`` times.

    Args:
        model_fn: Callable ``(input_size, **hypers) -> torch.nn.Module``.
        input_sizes: Iterable of integer feature dimensions to profile.
        batch_size: Number of rows in the random input.
        device: Torch device (string or ``torch.device``), e.g. ``'cuda'``,
            ``'cuda:1'`` or ``'cpu'``. CUDA devices are timed with CUDA events
            via ``compute_elapsed_gpu``; other devices with ``time.perf_counter``.
        trials: Number of timed forward passes per size.
        **hypers: Extra keyword arguments forwarded to ``model_fn``.

    Returns:
        pandas.DataFrame with one row per (size, trial) and these columns:

        - ``params``: trainable parameter count.
        - ``flops``: ``params * batch_size``, a rough cost proxy.
        - ``size``: the input width.
        - ``time``: seconds per forward pass.
        - ``time_per_example``: ``time / batch_size``.
        - ``memory``: peak bytes allocated on the CUDA device since the
          per-trial reset, including the resident model and input; 0 on
          non-CUDA devices.
        - ``batch_size`` and ``trial``.

    Any exception (e.g. out-of-memory) is printed and stops profiling of the
    remaining sizes; rows collected so far are still returned.
    """
    import time

    # torch.device normalises 'cuda', 'cuda:0', torch.device('cuda', 1), ...
    is_cuda = torch.device(device).type == 'cuda'

    def _release_cached_memory():
        # Return cached CUDA blocks so the next size starts from a clean pool.
        if is_cuda:
            torch.cuda.empty_cache()

    def _time_call(fn):
        # Seconds for one call of fn on the target device.
        if is_cuda:
            return compute_elapsed_gpu(fn)
        start = time.perf_counter()
        fn()
        return time.perf_counter() - start

    profiling_results = []

    for input_size in input_sizes:
        input_size = int(input_size)  # accept numpy integer scalars
        model = X = None
        try:
            # Generate the model for the current input size
            model = model_fn(input_size, **hypers).to(device).float()
            params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            # Allocate the dummy input directly on the target device (no host copy).
            X = torch.randn(batch_size, input_size, device=device).float()

            def fn():
                return model(X)

            _time_call(fn)  # warmup
            _release_cached_memory()
            for trial in range(trials):
                if is_cuda:
                    torch.cuda.reset_peak_memory_stats(device)
                runtime = _time_call(fn)
                memory_usage = torch.cuda.max_memory_allocated(device) if is_cuda else 0

                profiling_results.append({
                    'params': params,
                    'flops': params * batch_size,
                    'size': input_size,
                    'time': runtime,
                    'time_per_example': runtime / batch_size,
                    'memory': memory_usage,
                    'batch_size': batch_size,
                    'trial': trial,
                })

                # Cleanup to reduce memory usage for the next iteration
                _release_cached_memory()
        except Exception as e:
            # Deliberately best-effort: report and stop. NOTE(review): this
            # assumes sizes are increasing, so later ones would fail as well.
            print(e)
            break
        finally:
            # Free this size's tensors even when it failed, before the next size
            # (or the caller's next profile run) allocates.
            del model, X
            gc.collect()
            _release_cached_memory()

    return pd.DataFrame(profiling_results)


# Layer builders: each takes a width `input_size`, passes it as both layer
# dimensions to its constructor, and forwards any extra keyword arguments to
# the cola_nn builder.


def btt_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_opt_btt``."""
    return cola_nn.build_opt_btt(input_size, input_size, **kwargs)


def kron_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_tt`` with ``tt_dim=2, tt_rank=1``.

    NOTE(review): registered under 'kron'; presumably a 2-core rank-1 tensor
    train is used as a Kronecker-product layer - confirm in cola_nn.
    """
    return cola_nn.build_tt(input_size, input_size, tt_dim=2, tt_rank=1, **kwargs)


def low_rank_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_low_rank`` with ``rank_frac=0``.

    NOTE(review): the meaning of ``rank_frac=0`` (presumably the minimal rank)
    is defined by cola_nn - confirm.
    """
    return cola_nn.build_low_rank(input_size, input_size, rank_frac=0, **kwargs)


def blockdiag_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_blockdiag``."""
    return cola_nn.build_blockdiag(input_size, input_size, **kwargs)


def block_tt_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_block_tt`` with ``tt_dim=2, tt_rank=1``."""
    return cola_nn.build_block_tt(input_size, input_size, tt_dim=2, tt_rank=1, **kwargs)


def dense_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_dense``."""
    return cola_nn.build_dense(input_size, input_size, **kwargs)


def torch_dense_builder(input_size, **kwargs):
    """Plain ``torch.nn.Linear`` baseline; extra keyword arguments are ignored."""
    return torch.nn.Linear(input_size, input_size)


def monarch_4_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_monarch`` with its default block count.

    NOTE(review): the name suggests the default is 4 blocks - confirm in cola_nn.
    """
    return cola_nn.build_monarch(input_size, input_size, **kwargs)


def monarch_sqrt_builder(input_size, **kwargs):
    """Square layer from ``cola_nn.build_monarch`` with ``num_blocks=-1``.

    NOTE(review): the name suggests -1 selects sqrt(input_size) blocks - confirm.
    """
    return cola_nn.build_monarch(input_size, input_size, num_blocks=-1, **kwargs)


# Structures to profile, keyed by the label written to the 'struct' column.
# Commented-out entries are optional extra structures.
builder = {
    'dense': dense_builder,
    'monarch': monarch_4_builder,
    # 'monarch-sqrt': monarch_sqrt_builder,
    # 'blockdiag': blockdiag_builder,
    'btt': btt_builder,
    'kron': kron_builder,
    'low_rank': low_rank_builder,
    # 'block_tt': block_tt_builder
}
# Profile the model for different input sizes
# input_sizes = np.geomspace(16, 1024, 30).astype(int) ** 2
# Powers of two from 2**4 to 2**31. The dtype is pinned to int64 because the
# platform-default integer can be 32-bit (e.g. Windows with NumPy < 2), where
# 2**31 would overflow to a negative size.
input_sizes = 2**np.arange(4, 32, step=1, dtype=np.int64)
# input_sizes = sorted([50_000, 10_000, 5_000, 1_000, 500, 100, 50, 10])
# input_sizes = sorted([250_000, 100_000, 50_000, 10_000, 5_000, 1_000, 500, 100, 50, 10])

# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

# Skip sizes whose estimated cost max(100, batch_size) * d**expo exceeds this budget.
max_flops = 1e13
for batch_size in [1024]:
    dfs = []
    for struct, model_fn in builder.items():
        # Structures treated as quadratic in the width d. 'monarch' is the registry
        # key used for monarch_4_builder; the other names follow the builder
        # variable names (not all are currently registered in `builder`).
        expo = 2 if struct in ['dense', 'torch_dense', 'monarch', 'monarch-4'] else 3 / 2
        # float(d): with int64 sizes and the integer exponent 2, the estimate would be
        # computed in fixed-width numpy integers and silently wrap for large d (for
        # the power-of-two sizes here, to 0 once d >= 2**27), wrongly passing the budget.
        flops = [max(100, batch_size) * float(d) ** expo for d in input_sizes]
        ds = [d for d, f in zip(input_sizes, flops) if f <= max_flops]
        df = profile(model_fn, ds, batch_size=batch_size, device='cuda', trials=10)
        df['struct'] = struct
        dfs.append(df)
    df = pd.concat(dfs)

    # save the profiling results
    df.to_csv(f"timing_{batch_size}_single.csv", index=False)
