import os
import time
import json
from math import prod
import wandb
import torch
from torch.nn import CrossEntropyLoss, MSELoss
from model import set_seed
from nn.cola_nn import cola_parameterize, get_model_summary_and_flops, replace_layers_with_cola_layers
from learning.fns import select_factorizer
import nn
from tqdm import tqdm
from scaling_mlps.data_utils import data_stats
from scaling_mlps.data_utils.dataloader import get_loader
from scaling_mlps.utils.config import config_to_name
from scaling_mlps.utils.get_compute import get_compute
from scaling_mlps.utils.metrics import topk_acc, real_acc, AverageMeter
from scaling_mlps.utils.parsers import get_training_parser
from spline_expressivity.utils import compute_LC
from synthetic_data.utils_torch import get_random_f, get_sine_f, get_teacher_f
import math
import pickle
import numpy as np

SCHEDULERS = ["cosine", "none"]


def get_scheduler(opt, scheduler_name, step_size):
    """Build a learning-rate scheduler instance for ``opt``.

    Args:
        opt: optimizer whose learning rate is scheduled.
        scheduler_name: case-insensitive scheduler name, one of ``SCHEDULERS``.
        step_size: for ``"cosine"`` the annealing period ``T_max``; for
            ``"none"`` the ``StepLR`` step size (irrelevant in practice, since
            ``gamma=1.0`` never changes the learning rate).

    Returns:
        A ``torch.optim.lr_scheduler`` object attached to ``opt``.

    Raises:
        ValueError: if ``scheduler_name`` is not a supported scheduler.
    """
    scheduler_name = scheduler_name.lower()

    if scheduler_name == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=step_size)
    if scheduler_name == "none":
        # StepLR with gamma=1.0 keeps the learning rate constant.
        return torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=1.0)
    raise ValueError(f"Scheduler {scheduler_name} not supported.")



@torch.no_grad()
def test(model, loader, loss_fn, args):
    """Evaluate ``model`` on every batch yielded by ``loader``.

    Runs without gradient tracking and switches the model to eval mode; the
    caller is responsible for switching it back to train mode.

    Args:
        model: callable applied to ``x`` flattened to ``(-1, x.shape[-1])``;
            its output must be reshapeable to ``x.shape[:-1]``.
        loader: iterable of ``(x, y)`` batches that also supports ``len()``
            (used only for the progress-bar total).
        loss_fn: loss applied to the squeezed predictions and targets.
        args: unused; kept for call-site compatibility.

    Returns:
        Tuple ``(avg_loss, elapsed_seconds)``: the batch-size-weighted mean
        loss over all batches and the wall-clock duration of the evaluation.
    """
    start = time.time()
    model.eval()
    total_loss = AverageMeter()

    with tqdm(total=len(loader)) as pbar:
        for x, y in loader:
            # Flatten all leading dims so the model sees a (N, features) batch,
            # then restore the original leading shape.
            y_hat = model(x.reshape(-1, x.shape[-1])).reshape(x.shape[:-1])
            loss = loss_fn(y_hat.squeeze(), y.squeeze()).item()
            # Weight by batch size, the same way the training loss is accumulated.
            total_loss.update(loss, y.shape[0])
            pbar.update(1)

    return total_loss.get_avg(percentage=False), time.time() - start


def main(args):
    """Train a (CoLA-structured) model on a synthetic regression task.

    Builds the model and optimizer with ``cola_parameterize``, streams random
    training batches from a synthetic target function, evaluates on a fixed
    test stream every ``args.calculate_stats`` steps, logs statistics to wandb
    when ``args.wandb`` is set, and saves the final ``state_dict`` when
    ``args.save`` is set.

    Args:
        args: namespace from ``get_training_parser`` with the extra fields
            (e.g. ``num_classes``) filled in by the ``__main__`` block.
    """
    set_seed(args.seed)
    # Allow TF32 for CUDA matmuls (faster, reduced-precision matmul).
    torch.backends.cuda.matmul.allow_tf32 = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_builder = getattr(nn, args.model)
    base_ffn_expansion = 4  # equivalent to specifying a constant LR multiplier in μP. 4 works well for ViT.
    base_config = dict(dim_in=args.input_dim, dim_out=args.num_classes, depth=args.depth, width=args.width,
                       ffn_expansion=base_ffn_expansion, patch_size=args.patch_size, in_channels=args.in_channels,
                       shuffle_pixels=args.shuffle_pixels, heads=args.heads, dim_head=args.dim_head, attn_mult=args.attn_mult,
                       output_mult=args.output_mult, emb_mult=args.emb_mult,
                       layer_norm=args.layer_norm)
    # Target (scaled-up) config.
    target_config = base_config.copy()
    args.width = int(args.width * args.scale_factor)  # we update args.width to have it logged in wandb
    target_config['width'] = args.width
    target_config['ffn_expansion'] = args.ffn_expansion

    # Additional per-parameter LR multipliers, selected by parameter name.
    def extra_lr_mult_fn(param_name):
        if 'to_patch_embedding' in param_name or 'input_layer' in param_name:
            return args.input_lr_mult
        elif 'matrix_params.0' in param_name:
            print(f'scaling {param_name} LR by {args.lr_mult_1}')
            return args.lr_mult_1
        elif 'matrix_params.1' in param_name:
            print(f'scaling {param_name} LR by {args.lr_mult_2}')
            return args.lr_mult_2
        else:
            return 1

    # NOTE(review): this function is not passed to cola_parameterize below, so
    # init_mult_1/init_mult_2 have no visible effect here -- confirm intent.
    def extra_init_mult_fn(param_name):
        if 'matrix_params.0' in param_name:
            print(f'scaling {param_name} std by {args.init_mult_1}')
            return args.init_mult_1
        elif 'matrix_params.1' in param_name:
            print(f'scaling {param_name} std by {args.init_mult_2}')
            return args.init_mult_2
        else:
            return 1

    def zero_init_fn(weight, name):
        # Zero-initialize only weights explicitly flagged with a truthy `zero_init`.
        return hasattr(weight, 'zero_init') and weight.zero_init

    # CoLA structure
    struct = args.struct
    fact_cls = None
    if struct == "einsum":
        fact_cls = select_factorizer(name=args.fact)
        fact_cls = fact_cls(cores_n=args.cores_n, int_pow=args.int_pow)
        fact_cls.sample()
    cola_kwargs = dict(tt_dim=args.tt_dim, tt_rank=args.tt_rank, num_blocks=args.num_blocks, rank_frac=args.rank_frac,
                       fact_cls=fact_cls)
    # Initialize the scaled-up model with some linear layers replaced by CoLA
    # layers, and create an optimizer with appropriately scaled learning rates.
    if args.use_wrong_mult:
        print("#### WARNING: using wrong mult ####")

    model, opt = cola_parameterize(model_builder, base_config, args.lr, target_config=target_config, struct=struct,
                                   layer_select_fn=args.layers, zero_init_fn=zero_init_fn, extra_lr_mult_fn=extra_lr_mult_fn,
                                   device=device, cola_kwargs=cola_kwargs, use_wrong_mult=args.use_wrong_mult, optim_kwargs={"weight_decay": args.weight_decay})
    # Put the profiling input on the selected device instead of a hard-coded
    # 'cuda', so CPU-only runs do not crash here. (Sampled on CPU first to keep
    # the global RNG consumption unchanged.)
    fake_input = torch.randn(1, args.input_dim).to(device)
    info = get_model_summary_and_flops(model, fake_input)
    if struct == "einsum":
        info["cola_flops"] += fact_cls.flops
        print(fact_cls.layers)

    scheduler = get_scheduler(opt, args.scheduler, args.maximum_steps)
    loss_fn = MSELoss()

    # Create unique identifier
    run_name = config_to_name(args)
    path = os.path.join(args.checkpoint_folder, run_name)

    # Create the checkpoint folder and record the config the first time only.
    if not os.path.exists(path):
        os.makedirs(path)
        with open(os.path.join(path, 'config.txt'), 'w') as f:
            json.dump(args.__dict__, f, indent=2)

    # NOTE: gradient accumulation (args.accum_steps) is not implemented; the
    # optimizer steps after every batch of args.batch_size samples.

    class NpyLoader:
        """Fixed-length or endless stream of synthetic ``[X, y]`` batches.

        ``X`` is uniform on ``[-0.5, 0.5)^input_dim`` and ``y = f(X)``, where
        ``f`` is a sine-wave function (``freq_std != -1``) or a teacher model
        (``freq_std == -1``). ``f`` is built from a generator seeded with
        ``f_random_state``; inputs are drawn after reseeding that generator
        with ``data_random_state``. ``length == -1`` means an endless stream.
        """

        def __init__(
                self,
                batch_size,
                device,
                freq_std,
                input_dim,
                feature_num,
                num_waves,
                f_random_state=42,
                data_random_state=40,
                length=-1,
            ):
            self.batch_size = batch_size
            self.input_dim = input_dim
            self.device = device
            # Stored under the name reset() reads (it was previously saved as
            # `data_random`, which made reset() raise AttributeError).
            self.data_random_state = data_random_state

            self.rng = torch.Generator(device)
            self.rng.manual_seed(f_random_state)

            if freq_std != -1:
                self.f = get_sine_f(feature_num, num_waves=num_waves,freq_std=freq_std, rng=self.rng, device=device)
                print("Using Sine waves as test function")
            else:
                self.f = get_teacher_f(feature_num, [1024] * 6, self.rng, device)
                print("Using teacher model as test function")

            self.length = length
            self.infty = (length == -1)

            self.rng.manual_seed(data_random_state)

        def __iter__(self):
            # Local countdown so concurrent iterators do not share state.
            remaining = self.length
            while self.infty or remaining > 0:
                if not self.infty:
                    remaining -= 1
                X = torch.rand(self.batch_size, self.input_dim, generator=self.rng, device=self.device, dtype=torch.float32) - 0.5
                y = self.f(X)
                yield [X, y]

        def __len__(self):
            # -1 for an endless stream, which makes len() raise ValueError.
            return self.length

        def reset(self):
            """Rewind the input stream to its first batch."""
            self.rng.manual_seed(self.data_random_state)

    train_loader = NpyLoader(args.batch_size, device, args.freq_std, args.input_dim, args.feature_num, args.num_waves, data_random_state=0)
    test_loader = NpyLoader(args.batch_size, device, args.freq_std, args.input_dim, args.feature_num, args.num_waves, data_random_state=1, length=args.test_size)

    # Sanity check: both loaders seed the target function identically, so the
    # test loader's f must reproduce the training targets. (This consumes the
    # first training batch, exactly as before.)
    x0, y0 = next(iter(train_loader))
    target_error = (y0 - test_loader.f(x0)).norm()
    if not target_error < 1e-5:
        raise RuntimeError(f"Train and test have different test function: error {target_error}")

    if args.wandb:
        # Copy so that merging model/factorizer info does not mutate `args`.
        config = dict(args.__dict__)
        config.update(info)
        if struct == "einsum":
            config.update(fact_cls.log_data())
            exprs = fact_cls.get_unique_ein_expr()
            print(exprs)
            formatted_combined = {f"expr{idx}": f"{key}({val:d})" for idx, (key, val) in enumerate(exprs.items())}
            config.update(formatted_combined)
        wandb.init(
            project=args.wandb_project,
            name=run_name,
            config=config,
            tags=["pretrain", args.dataset],
        )

    prev_hs = None
    train_loss_meter = AverageMeter()

    pb = tqdm(total=args.maximum_steps)
    for step, (x, y) in enumerate(train_loader):
        model.train()
        y_hat = model(x)
        loss = loss_fn(y_hat.squeeze(), y.squeeze())
        loss.backward()

        if args.clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        opt.step()
        opt.zero_grad()
        train_loss_meter.update(loss.item(), y.shape[0])
        scheduler.step()

        if step % args.calculate_stats == 0:
            train_loss = train_loss_meter.get_avg(percentage=False)
            train_loss_meter.reset()

            model.clear_features()
            # Evaluate on the same test inputs every time: the feature deltas
            # (dh) below are only meaningful when hs and prev_hs come from
            # identical inputs.
            test_loader.reset()
            test_loss, test_time = test(model, test_loader, loss_fn, args)
            # Features recorded during the test pass, one tensor per feature buffer.
            hs = [torch.cat(h.buffer, dim=0) for h in model.get_features()]
            if prev_hs is None:
                prev_hs = hs
            h_norm = [torch.norm(h, dim=1).mean() / h.shape[1] ** 0.5 for h in hs]  # should be O(1)
            dh_norm = [torch.norm(h - prev_h, dim=1).mean() / h.shape[1] ** 0.5
                       for h, prev_h in zip(hs, prev_hs)]  # should be O(1)
            prev_hs = hs

            # TODO: local complexity (compute_LC) is not computed yet.

            if args.wandb:
                logs = {
                    "step": step,
                    "compute": step * info['cola_flops'] * args.batch_size,
                    "test_loss": test_loss,
                    "train_loss": train_loss,
                    "Inference time": test_time,
                }
                for i, (h_n, dh_n) in enumerate(zip(h_norm, dh_norm)):
                    logs[f'h_{i}'] = h_n.item()
                    logs[f'dh_{i}'] = dh_n.item()
                for name, p in model.named_parameters():
                    if hasattr(p, 'rms'):
                        logs[f'rms/{name}'] = p.rms
                wandb.log(logs)

            pb.set_description(f"Steps {step}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f}")

        pb.update(1)
        # NOTE(review): the break comes after the update, so steps 0..maximum_steps
        # (maximum_steps + 1 optimizer steps) run while the cosine schedule uses
        # T_max=maximum_steps -- confirm this is intended.
        if step >= args.maximum_steps:
            break
    pb.close()

    if args.save:
        torch.save(
            model.state_dict(),
            os.path.join(path, "final_checkpoint.pt"),
        )


if __name__ == "__main__":
    parser = get_training_parser()
    args = parser.parse_args()

    # Presumably "synthetic_<name>": split only on the first "_" so a task name
    # that itself contains underscores does not break the two-way unpacking.
    if "synthetic" in args.dataset:
        args.dataset, args.dataset_name = args.dataset.split("_", 1)
    args.num_classes = data_stats.CLASS_DICT[args.dataset]

    if args.n_train is None:
        args.n_train = data_stats.SAMPLE_DICT[args.dataset]

    if args.crop_resolution is None:
        args.crop_resolution = args.resolution

    main(args)
