import sys
import time
from datetime import timedelta
import argparse
import wandb
import optuna
from sqlalchemy import URL
import logging

from tqdm import tqdm
import jax
from jax import random
import jax.numpy as jnp

import tensorflow as tf
# Hide every GPU from TensorFlow so it does not reserve GPU memory alongside JAX
# (presumably TF is only needed by the data loaders — confirm). TensorFlow's
# physical device type string is upper-case 'GPU'; the lower-case 'gpu' used
# previously may not match any device, leaving the GPUs visible to TF.
tf.config.experimental.set_visible_devices([], 'GPU')

from models.trainer import SSLTrainState, train_step, save_state, load_state
from models.utils import get_optimizer, get_current_lr, pretty_print
from models.losses import get_training_loss

from data.loaders import get_ssl_cifar_loaders, get_ssl_imagenet_loaders, get_ssl_cifar100_loaders
from data.loaders import get_ssl_im32_loaders, get_ssl_stl10_loaders, get_ssl_tinyimagenet_loaders
from data.loaders import replicate, unreplicate, get_iter
from metrics.lep import eval_step

from networks.vgg import VGG8
from networks.resnet import ResNet50, ResNet18
from models.ssl_base import ssl_agent, supervised_agent


def _load_datasets(args):
    """Build the data loaders selected by ``args.dataset``.

    Returns:
        ``(ssl_ds, lep_ds, test_ds, num_classes)``: the three loaders returned by
        the dataset's ``get_ssl_*_loaders(args.bsz)`` factory, plus its label count.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # dataset name -> (loader factory, number of label classes)
    loaders = {
        'cifar10': (get_ssl_cifar_loaders, 10),
        'cifar100': (get_ssl_cifar100_loaders, 100),
        'stl10': (get_ssl_stl10_loaders, 10),
        'imagenet': (get_ssl_imagenet_loaders, 1000),
        'imagenet32': (get_ssl_im32_loaders, 1000),
        'tinyimagenet': (get_ssl_tinyimagenet_loaders, 200),
    }
    try:
        loader, num_classes = loaders[args.dataset]
    except KeyError:
        raise ValueError(f'unsupported dataset: {args.dataset!r}') from None
    ssl_ds, lep_ds, test_ds = loader(args.bsz)
    return ssl_ds, lep_ds, test_ds, num_classes


def _get_encoder(name):
    """Return ``(encoder_class, out_channels)`` for the encoder called ``name``.

    Raises:
        ValueError: if ``name`` is not a supported encoder.
    """
    encoders = {
        'vgg': (VGG8, 512),
        'resnet50': (ResNet50, 2048),
        'resnet18': (ResNet18, 512),
    }
    try:
        return encoders[name]
    except KeyError:
        raise ValueError(f'unsupported encoder: {name!r}') from None


def _to_python_scalars(metrics):
    """Replace every single-element JAX array in ``metrics`` with its Python scalar."""
    return jax.tree_util.tree_map(
        lambda x: x.item() if isinstance(x, jnp.ndarray) and x.size == 1 else x, metrics)


def _print_metrics(metrics):
    """Print ``metrics`` sorted by name (3 decimals), skipping the eigenvalue arrays."""
    print('\n' + '\t\t'.join(f'{k}: {v:.3f}' for k, v in sorted(metrics.items())
                             if k not in ('eigvals', 'eigvals_proj')) + '\n')


def main(args):
    """Build, train and evaluate the agent configured by ``args``.

    The agent is a ``supervised_agent`` when ``args.loss == 'supervised'``,
    otherwise an ``ssl_agent``. It is evaluated with ``eval_step`` before
    training and after every epoch. In optuna mode (``args.optuna``) only the
    final epoch is evaluated, and no metrics are printed or logged.

    Args:
        args: the ``argparse.Namespace`` produced by this script's CLI parser.

    Returns:
        ``metrics['offline_lep_test_acc']`` from the last evaluation.

    Raises:
        ValueError: for an unsupported dataset or encoder name.
        RuntimeError: if no evaluation ran (optuna mode with no epochs left to train).
    """
    if not args.nolog:
        wandb.init(project="isoloss", entity="username", name=args.experiment_name, config=args)

    jax.config.update("jax_default_device", jax.devices()[args.device])

    # get data and the encoder
    ssl_ds, lep_ds, test_ds, num_classes = _load_datasets(args)
    encoder, out_channels = _get_encoder(args.encoder)

    # Collect every ``--dp_*`` option (prefix stripped) for the direct predictor.
    dp_args = {}
    if args.pred == 'dp':
        dp_args = {k.split('_', 1)[1]: v for k, v in vars(args).items() if k.startswith('dp_')}

    hidden_sizes = tuple(int(x) for x in args.proj_layers.split(','))
    # batch norm on every projector layer except the last one
    proj_args = {'hidden_sizes': hidden_sizes,
                 'bnorm': (True,) * (len(hidden_sizes) - 1) + (False,)}
    if args.proj == 'id':
        proj_args['hidden_sizes'] = (out_channels,)

    # get the agent
    if args.loss == 'supervised':
        assert args.pred == 'id', 'supervised learning only supports id prediction'
        agent = supervised_agent(greedy=args.greedy, encoder=encoder,
                                 proj_type=args.proj, proj_args=proj_args,
                                 pred_type='id', num_classes=num_classes, iso=args.iso,
                                 dp_args={}, dataset=args.dataset)
    else:
        agent = ssl_agent(greedy=args.greedy, encoder=encoder,
                          proj_type=args.proj, proj_args=proj_args,
                          pred_type=args.pred, dp_args=dp_args, iso=args.iso,
                          num_classes=num_classes, dataset=args.dataset)

    # initialize the agent from one training batch
    rng = random.PRNGKey(args.seed)
    rng, param_key = random.split(rng, 2)
    im1, im2, y = next(iter(ssl_ds))

    variables = agent.init(param_key, im1, train=False).unfreeze()
    params = variables['params']
    pretty_print(params)
    batch_stats = variables['batch_stats']

    if args.pred == 'dp':
        dp_aux = variables['direct_pred']
        pretty_print(dp_aux)
    else:
        dp_aux = None

    # get the target network (only simsiam / directloss use one)
    if args.loss in ['simsiam', 'directloss']:
        if args.ema:
            print('\nUsing EMA\n')
            rng, param_key = random.split(rng, 2)
            target_variables = agent.init(param_key, im1, train=False, is_target_net=True).unfreeze()
            target_params = target_variables['params']
            tg_batch_stats = target_variables['batch_stats']
        else:
            print('\nNot using EMA\n')
            target_params = params
            tg_batch_stats = batch_stats
    else:
        target_params = None
        tg_batch_stats = None

    # initialize the state
    opt = get_optimizer(args, len(ssl_ds))
    loss_fn = get_training_loss(args, num_classes)
    state = SSLTrainState.create(apply_fn=agent.apply,
                                 params=params, batch_stats=batch_stats,
                                 target_params=target_params, tg_batch_stats=tg_batch_stats,
                                 tx=opt, direct_pred=dp_aux)

    if args.load_state:
        state = load_state(state, args.model_name, args.load_epoch)

    if args.parallel:
        state = replicate(state)
        # loss_fn and the parallel flag (args 2 and 3) are static, broadcast to every device
        train_step_fn = jax.pmap(train_step, axis_name='device',
                                 static_broadcasted_argnums=(2, 3))
    else:
        train_step_fn = train_step

    # `metrics` holds the latest evaluation; None until one has run.
    metrics = None
    # wandb tables of eigenvalues; only created when logging them outside optuna mode
    eigval_table = eigval_proj_table = None

    # initial evaluation
    if not args.optuna:
        do_offline_lep = args.load_epoch == 0 if args.load_state else True
        metrics = eval_step(state, lep_ds, test_ds, 0, num_classes=num_classes,
                            do_offline_lep=do_offline_lep, parallel=args.parallel)
        metrics['epoch'] = args.load_epoch if args.load_state else 0

        if not args.nolog:
            metrics = _to_python_scalars(metrics)
            if args.log_eigvals:
                # one row now; a row is appended after every epoch below
                eigvals = metrics['eigvals'].reshape(1, -1).tolist()
                eigvals_proj = metrics['eigvals_proj'].reshape(1, -1).tolist()
                eigval_table = wandb.Table(
                    data=eigvals,
                    columns=[f'eigval_{i}' for i in range(1, len(eigvals[0]) + 1)])
                eigval_proj_table = wandb.Table(
                    data=eigvals_proj,
                    columns=[f'eigval_proj_{i}' for i in range(1, len(eigvals_proj[0]) + 1)])
            del metrics['eigvals'], metrics['eigvals_proj']

            wandb.log(metrics)

        _print_metrics(metrics)

    # training loop
    num_batches = len(ssl_ds)
    tot_steps = args.num_epochs * num_batches
    for epoch in range(1 + args.load_epoch, args.num_epochs + 1):

        ssl_it = get_iter(ssl_ds, args.parallel)

        start_time = time.perf_counter()
        train_metrics = {}

        for idx, batch in tqdm(enumerate(ssl_it), total=num_batches):

            state, batch_train_metrics = train_step_fn(state, batch, loss_fn, args.parallel)

            if args.loss in ['simsiam', 'directloss']:
                if args.ema:
                    # global step index across epochs, passed to the EMA update
                    step_num = idx + num_batches * (epoch - 1)
                    state = state.update_ema(step_num, tot_steps)
                else:
                    # without EMA the target network copies the online network
                    state = state.replace(target_params=state.params, tg_batch_stats=state.batch_stats)

            if args.parallel:
                batch_train_metrics = unreplicate(batch_train_metrics)

            # accumulate each batch's metrics divided by the number of batches
            if idx == 0:
                train_metrics = jax.tree_util.tree_map(lambda x: x / num_batches, batch_train_metrics)
            else:
                train_metrics = jax.tree_util.tree_map(
                    lambda x, y: x + y / num_batches, train_metrics, batch_train_metrics)

        is_last_epoch = epoch == args.num_epochs
        # Optuna trials skip per-epoch evaluation, but still need the final one:
        # its 'offline_lep_test_acc' is the value this function returns.
        if not args.optuna or is_last_epoch:
            do_offline_lep = is_last_epoch or epoch in [5, 10, 20, 50, 100, 200, 500, 1000, 2000]
            metrics = eval_step(state, lep_ds, test_ds, epoch, num_classes=num_classes,
                                do_offline_lep=do_offline_lep, parallel=args.parallel)

        epoch_time = time.perf_counter() - start_time

        if not args.optuna:
            metrics['epoch'] = epoch
            metrics['epoch_time'] = epoch_time
            state_for_log = unreplicate(state) if args.parallel else state
            metrics['learning_rate'] = get_current_lr(state_for_log, args)
            if args.pred == 'mlp':
                metrics['predictor_learning_rate'] = get_current_lr(state_for_log, args, pred=True)
            metrics = {**metrics, **train_metrics}

            if not args.nolog:
                metrics = _to_python_scalars(metrics)
                if args.log_eigvals:
                    eigval_table.add_data(*metrics['eigvals'].tolist())
                    eigval_proj_table.add_data(*metrics['eigvals_proj'].tolist())
                del metrics['eigvals'], metrics['eigvals_proj']
                wandb.log(metrics)

            _print_metrics(metrics)

        if args.save_model:
            to_save = unreplicate(state) if args.parallel else state
            save_state(to_save, args, epoch)

    if eigval_table is not None:
        wandb.log({'eigvals_record': eigval_table, 'eigvals_proj_record': eigval_proj_table})

    if metrics is None:
        raise RuntimeError('no evaluation was run (optuna mode with no epochs left to train)')
    return metrics['offline_lep_test_acc']

def optuna_objective(trial, args):
    """Optuna objective: sample the selected hyperparameters into ``args`` and train.

    Each enabled ``--optimize_*`` flag draws a value with ``trial.suggest_float``
    in its ``[--optuna_*_min, --optuna_*_max]`` range. The value is written to
    the same ``args`` attribute the CLI defines (``lr``, ``decorr_coeff``,
    ``push_coeff``), so training actually uses it.

    Args:
        trial: the optuna trial used to sample values.
        args: the script's ``argparse.Namespace``; mutated in place.

    Returns:
        The value returned by ``main(args)`` (the final offline LEP test accuracy).

    Raises:
        ValueError: if no ``--optimize_*`` flag is set.
    """
    if not (args.optimize_lr or args.optimize_decorr or args.optimize_push):
        raise ValueError('No hyperparameters were chosen to optimize\n'
                         'Use --optimize_lr, --optimize_push or --optimize_decorr flags')
    if args.optimize_lr:
        args.lr = trial.suggest_float('lr', args.optuna_lr_min, args.optuna_lr_max)
    if args.optimize_decorr:
        # the CLI option is --decorr_coeff, so the value must land in args.decorr_coeff
        args.decorr_coeff = trial.suggest_float('decorr_coeff', args.optuna_decorr_min, args.optuna_decorr_max)
    if args.optimize_push:
        # the CLI option is --push_coeff, so the value must land in args.push_coeff
        args.push_coeff = trial.suggest_float('push_coeff', args.optuna_push_min, args.optuna_push_max)
    return main(args)

if __name__ == '__main__':
    import os

    def _str2bool(value):
        """Parse a boolean command-line value.

        Replaces ``type=bool``, for which ``bool('False')`` is ``True`` (any
        non-empty string would enable the option).
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='SSL training experiment')
    parser.add_argument('--nolog', action='store_true', help='stop log to wandb')
    parser.add_argument('--optuna', action='store_true', help='use optuna for hyperparameter search (learning rate, decorr and push coefficients)')
    parser.add_argument('--log_eigvals', action='store_true', help='log eigvals to wandb')
    parser.add_argument('--dataset', choices=['cifar10', 'cifar100', 'stl10', 'imagenet', 'tinyimagenet', 'imagenet32'], default='cifar10', help='dataset')
    parser.add_argument('--experiment_name', type=str, default='no particular experiment', help='experiment name (wandb run name; also names the optuna study and output files)')
    parser.add_argument('--encoder', choices=['resnet18', 'resnet50', 'vgg'], default='resnet18', help='encoder type')
    parser.add_argument('--proj', choices=['mlp', 'id'], default='mlp', help='projection head type')
    parser.add_argument('--pred', choices=['mlp', 'id', 'dp'], default='dp', help='prediction head type')
    parser.add_argument('--device', type=int, default=0, help='GPU device number')
    parser.add_argument('--parallel', action='store_true', help='parallel training')
    parser.add_argument('--bsz', type=int, default=512, help='batch size')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--num_epochs', type=int, default=800, help='number of epochs')
    parser.add_argument('--warmup_epochs', type=int, default=10, help='warmup epochs')
    parser.add_argument('--ema', action='store_true', help='use exponential moving average')
    parser.add_argument('--greedy', action='store_true', help='greedy layer wise training')
    parser.add_argument('--loss', choices=['supervised', 'simsiam', 'directloss', 'simclr', 'vicreg'], default='directloss', help='loss function')
    parser.add_argument('--iso', action='store_true', help='use isotropic (direct) loss')
    parser.add_argument('--iso2', action='store_true', help='use fully isotropic (direct) loss')
    parser.add_argument('--pull_coeff', type=float, default=1., help='pull loss coefficient')

    parser.add_argument('--push_coeff', type=float, default=1., help='push loss coefficient')
    parser.add_argument('--optimize_push', action='store_true', help='use optuna to optimize push coefficient')
    parser.add_argument('--optuna_push_min', type=float, default=0.01, help='lowest push coefficient for optuna')
    parser.add_argument('--optuna_push_max', type=float, default=1000., help='highest push coefficient for optuna')

    parser.add_argument('--decorr_coeff', type=float, default=100., help='decorr loss coefficient')
    parser.add_argument('--optimize_decorr', action='store_true', help='use optuna to optimize decorr coefficient')
    parser.add_argument('--optuna_decorr_min', type=float, default=0.01, help='lowest decorr coefficient for optuna')
    parser.add_argument('--optuna_decorr_max', type=float, default=1000., help='highest decorr coefficient for optuna')

    parser.add_argument('--distance_metric', choices=['l2', 'cosine', 'pseudo_cosine'], default='cosine', help='distance metric')
    parser.add_argument('--opt', choices=['adamw', 'lars', 'sgd'], default='sgd', help='optimizer')
    parser.add_argument('--cosine_lr_decay', action='store_true', help='use a cosine schedule for lr decay, else constant')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--pred_lr_coeff', type=float, default=10., help='prediction head learning rate multiplier')
    parser.add_argument('--optimize_lr', action='store_true', help='use optuna to optimize learning rate')
    parser.add_argument('--optuna_lr_min', type=float, default=0.0001, help='lowest learning rate for optuna')
    parser.add_argument('--optuna_lr_max', type=float, default=0.1, help='highest learning rate for optuna')
    parser.add_argument('--wd', type=float, default=4e-4, help='weight decay')
    parser.add_argument('--proj_layers', type=str, default='2048,2048', help='proj mlp layers')

    parser.add_argument('--dp_update_freq', type=int, default=1, help='frequency of Wp updates')
    parser.add_argument('--dp_alpha', type=float, default=0.5, help='power of EVs')
    parser.add_argument('--dp_tau', type=float, default=0.3, help='ema coeff for DP')
    parser.add_argument('--dp_normalize', type=_str2bool, default=False, help='normalize the eigenvalues')
    parser.add_argument('--dp_pc_num_components', type=int, default=-1, help='number of components to keep')
    parser.add_argument('--dp_pc_thresh', type=float, default=0., help='threshold for eigenvalues')
    parser.add_argument('--dp_pc_cutoff_method', choices=['dim', 'thresh', 'none'], default='none', help='cutoff method for eigenvalues')
    parser.add_argument('--dp_eigval_floor', type=float, default=0., help='floor for eigenvalues')

    parser.add_argument('--save_model', action='store_true')
    parser.add_argument('--load_state', action='store_true')
    parser.add_argument('--model_name', type=str, default='')
    parser.add_argument('--load_epoch', type=int, default=0)

    args = parser.parse_args()

    if args.optuna:
        if not args.nolog:
            # runs are not logged to wandb in optuna mode; require the flag explicitly
            parser.error('--optuna requires --nolog (runs are not logged to wandb when using optuna)')
        # Add stream handler of stdout to show the messages
        optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
        # the log file and the best-trial summary are written into this directory
        os.makedirs('optuna', exist_ok=True)
        out_file = 'optuna/{}_{}.log'.format(args.dataset, args.experiment_name)
        optuna.logging.enable_propagation()
        optuna.logging.enable_default_handler()
        # Redirect the first handler on the 'optuna' logger into out_file. The file is
        # intentionally left open so the handler can keep writing during the study.
        optuna.logging.get_logger("optuna").handlers[0].stream = open(out_file, 'w')
        # TODO: substitute real MySQL connection details for the {sql_*} placeholders;
        # this string is a template and is not formatted anywhere.
        storage_name = "mysql://{sql_user}:{sql_password}@{sql_host}:{sql_port}/{sql_db}?charset=utf8mb4"

        study = optuna.create_study(direction='maximize', study_name=args.experiment_name, storage=storage_name, load_if_exists=True)
        study.optimize(lambda trial: optuna_objective(trial, args), n_trials=20)

        trial = study.best_trial
        summary = ['Best trial:', '  Value: {}'.format(trial.value), '  Params: ']
        summary += ['    {}: {}'.format(key, value) for key, value in trial.params.items()]
        print('\n'.join(summary))

        # save best trial to file
        with open('optuna/{}_{}.txt'.format(args.dataset, args.experiment_name), 'w') as f:
            f.write('\n'.join(summary) + '\n')
    else:
        main(args)
        