import os
import yaml

import hydra

import torch

from utils.utils import *

import models
from train_utils import *
from data_utils import shapes_dict

def compute_fro(model, args, model_diff=None):
    """Compute the Frobenius (L2) norm of a model's parameters at every saved epoch.

    Checkpoints are discovered in ``args.ckpt_path`` as regular files named
    ``epochs=<N>.pt`` and processed in increasing ``N``. ``epochs=0.pt`` is
    skipped, and any other file is ignored. For each checkpoint, the ``'last'``
    state dict is loaded into ``model``, minus any ``model_preact_hl1`` entries.

    If ``model_diff`` is given, the checkpoint with the same file name is also
    loaded from ``args.diff_ckpt_path`` into ``model_diff``. The norm of the
    parameter-wise difference ``model - model_diff`` is then computed instead.

    Args:
        model: Module whose parameters are measured; its weights are overwritten.
        args: Object exposing ``ckpt_path`` (and ``diff_ckpt_path`` when
            ``model_diff`` is given).
        model_diff: Optional second module with the same parameter layout as
            ``model``; its weights are overwritten as well.

    Returns:
        list[float]: One norm per checkpoint, in epoch order.
    """
    prefix, suffix = 'epochs=', '.pt'

    def _epoch(name):
        """Return N for a file called ``epochs=<N>.pt``, otherwise None."""
        if not (name.startswith(prefix) and name.endswith(suffix)):
            return None
        number = name[len(prefix):-len(suffix)]
        return int(number) if number.isdecimal() else None

    def _load_last(path, target):
        """Load the ``'last'`` state dict at ``path`` into ``target``, dropping ``model_preact_hl1`` keys."""
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the values into target's own tensors anyway.
        state = torch.load(path, map_location='cpu')['last']
        target.load_state_dict({k: v for k, v in state.items() if 'model_preact_hl1' not in k})

    ckpt_names = sorted(
        (fd for fd in os.listdir(args.ckpt_path)
         if fd != 'epochs=0.pt'
         and _epoch(fd) is not None
         and os.path.isfile(os.path.join(args.ckpt_path, fd))),
        key=_epoch,  # numeric order, so epochs=10 comes after epochs=2
    )

    norms = []
    # Pure measurement: no autograd graph is needed for the norm arithmetic.
    with torch.no_grad():
        for i, ckpt_name in enumerate(ckpt_names):
            _load_last(os.path.join(args.ckpt_path, ckpt_name), model)

            if model_diff is None:
                tensors = list(model.parameters())
            else:
                _load_last(os.path.join(args.diff_ckpt_path, ckpt_name), model_diff)
                tensors = [p - q for p, q in zip(model.parameters(), model_diff.parameters())]

            # Global Frobenius norm = sqrt of the sum of squared per-tensor norms.
            total_norm = sum(torch.norm(t).item() ** 2 for t in tensors) ** 0.5

            print(f'norm: {total_norm:.2f} ({i+1}/{len(ckpt_names)})')
            norms.append(total_norm)

    return norms

# Number of output classes for each dataset name a run may have been trained on.
_DATASET_NUM_CLASSES = {'cifar10': 10, 'cifar100': 100, 'tiny_imagenet': 200}


def _num_classes(dataset):
    """Return the class count for ``dataset``; raise ``Exception`` if unsupported."""
    try:
        return _DATASET_NUM_CLASSES[dataset]
    except KeyError:
        raise Exception('Dataset not supported.') from None


def _wandb_config_path(team_path, exp_name, run_name):
    """Return the path of a run's ``wandb/latest-run/files/config.yaml``."""
    return os.path.join(team_path, exp_name, run_name,
                        'wandb', 'latest-run', 'files', 'config.yaml')


def _load_run_args(config_path):
    """Read a W&B ``config.yaml`` into a DotDict of each entry's ``'value'``.

    The W&B bookkeeping keys ``_wandb`` and ``wandb_version`` are dropped first.
    """
    with open(config_path, 'r') as f:
        wandb_config = yaml.safe_load(f)
    wandb_config.pop('_wandb', None)
    wandb_config.pop('wandb_version', None)
    return DotDict({k: v['value'] for k, v in wandb_config.items()})


def _build_model(run_args, num_classes):
    """Instantiate the run's architecture through ``models.get_model`` with a relu activation."""
    return models.get_model(run_args.model,
                            num_classes,
                            False,
                            shapes_dict[run_args.dataset],
                            run_args.model_width,
                            'relu',
                            droprate=run_args.droprate)


def _load_best_checkpoint(model, ckpt_path):
    """Load ``'last'`` from ``ckpt_path/best_test_err.pt`` into ``model``, dropping ``model_preact_hl1`` keys."""
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    model_dict = torch.load(os.path.join(ckpt_path, 'best_test_err.pt'), map_location='cpu')['last']
    model.load_state_dict({k: v for k, v in model_dict.items() if 'model_preact_hl1' not in k})


@hydra.main(version_base=None, config_path='.', config_name='trajectory_params')
def main(cfg):
    """Report parameter norms for the checkpoints of a trained run.

    The run is selected by ``cfg.team_path`` / ``exp_name`` / ``run_name``.
    Its architecture is rebuilt from the run's W&B config, and the checkpoints
    are read from ``<run>/ckpt``. ``cfg.norm`` and ``cfg.traj`` choose the output:

    * ``'l1'``, ``'nuc'``, or ``'fro'`` with ``traj`` false: print that norm
      for ``best_test_err.pt``.
    * ``'fro'`` with ``traj`` true: compute the Frobenius norm for every
      ``epochs=<N>.pt`` checkpoint and ``torch.save`` the list under
      ``<run>/trajectory``. When ``diff_exp_name`` and ``diff_run_name`` are
      both set, the norm of the difference to the second run's matching
      checkpoint is used instead. The two runs must share model, class count,
      width, droprate and epochs.
    """
    # Remove the log file Hydra created for this script. This is best-effort:
    # whether it exists in the cwd depends on Hydra's run-dir configuration.
    try:
        os.remove(f'{os.path.splitext(os.path.basename(__file__))[0]}.log')
    except FileNotFoundError:
        pass
    os.umask(0)

    # Argument intake and validation
    args = DotDict(cfg)

    assert args.exp_name and args.run_name, \
        'Must specify experiment and run name.'
    if args.traj:
        assert args.norm == 'fro', \
            'Must use frobenius norm if computing trajectories.'
    has_diff = bool(args.diff_exp_name and args.diff_run_name)

    run_dir = os.path.join(args.team_path, args.exp_name, args.run_name)
    args.ckpt_path = os.path.join(run_dir, 'ckpt')
    args.traj_path = os.path.join(run_dir, 'trajectory')
    os.makedirs(args.traj_path, mode=0o777, exist_ok=True)

    if has_diff:
        args.diff_ckpt_path = os.path.join(args.team_path, args.diff_exp_name, args.diff_run_name, 'ckpt')

    log_std(args.traj_path, incl_stderr=False)

    args.run_wandb_config_path = _wandb_config_path(args.team_path, args.exp_name, args.run_name)
    run_args = _load_run_args(args.run_wandb_config_path)
    num_classes = _num_classes(run_args.dataset)
    model = _build_model(run_args, num_classes)

    # Optional second run whose weights are subtracted in the trajectory norm.
    model_diff = None
    if has_diff:
        args.run_diff_wandb_config_path = _wandb_config_path(args.team_path, args.diff_exp_name,
                                                             args.diff_run_name)
        run_args_diff = _load_run_args(args.run_diff_wandb_config_path)
        num_classes_diff = _num_classes(run_args_diff.dataset)

        assert run_args_diff.model == run_args.model and \
               num_classes_diff == num_classes and \
               run_args_diff.model_width == run_args.model_width and \
               run_args_diff.droprate == run_args.droprate and \
               run_args_diff.epochs == run_args.epochs, \
               'Run models must match.'

        model_diff = _build_model(run_args_diff, num_classes_diff)

    if args.norm == 'l1':
        _load_best_checkpoint(model, args.ckpt_path)
        with torch.no_grad():
            l1_norm = float(sum(p.abs().sum() for p in model.parameters()))
        print(f'l1_norm: {l1_norm}')
    elif args.norm == 'fro' and args.traj:
        norms = compute_fro(model, args, model_diff)
        if has_diff:
            torch.save(norms, os.path.join(args.traj_path, f'norms_diff_{args.diff_exp_name}_{args.diff_run_name}.pt'))
        else:
            torch.save(norms, os.path.join(args.traj_path, 'norms.pt'))
    elif args.norm == 'fro' and not args.traj:
        _load_best_checkpoint(model, args.ckpt_path)
        with torch.no_grad():
            fro_norm = float(sum((p**2).sum() for p in model.parameters())**0.5)
        print(f'fro_norm: {fro_norm}')
    elif args.norm == 'nuc':
        _load_best_checkpoint(model, args.ckpt_path)
        nuclear_norm_total = 0.0

        num_params = sum(1 for _ in model.named_parameters())
        with torch.no_grad():
            for _, param in tqdm(model.named_parameters(), desc='Aggregating norm...', total=num_params):
                # Collapse all trailing dims into one so every parameter becomes a
                # (size(0), -1) matrix; 2-D weights are unchanged and 1-D vectors
                # become single columns. reshape (unlike view) also accepts
                # non-contiguous tensors.
                matrix = param.reshape(param.size(0), -1)
                nuclear_norm_total += torch.norm(matrix, 'nuc').item()

        print(f'nuc_norm: {nuclear_norm_total}')
    else:
        raise Exception('Your norm hasn\'t been implemented!')

if __name__ == '__main__':
    # No cfg is passed here: the @hydra.main decorator on main() builds it
    # from the 'trajectory_params' config before invoking the function.
    main()
