import argparse
import datetime
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import torch
import time
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import math
import logging
import sys
import timm
assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory
import util.misc as misc
import pandas as pd
import warnings
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.misc import seed_torch
from util.lr_sched import adjust_learning_rate
from util.loss import L2RankLoss
from data.LSDataset import LSDataset_finetune
from data.SJTUDataset import SJTUDataset
from data.WPCDataset import WPCDataset
from model.DisPA import DisPA
from util.logistic_4_fitting import logistic_4_fitting

# Suppress every Python warning for the whole process; this also hides
# warnings emitted by the imported libraries, not just this script.
warnings.filterwarnings('ignore')

def get_args_parser():
    """Build the command-line parser for the training script.

    The parser is created with ``add_help=False``: it defines no
    ``-h/--help`` option of its own, which lets it be used as a parent of
    another ``ArgumentParser`` without an option conflict.

    Returns:
        argparse.ArgumentParser: parser exposing the data, optimisation and
        model options consumed by ``main``.
    """
    parser = argparse.ArgumentParser('training', add_help=False)

    # Data / run selection.
    # ``choices`` mirrors the set of datasets ``main`` knows how to build, so
    # a typo is rejected at parse time with a usage message.
    parser.add_argument('--dataset', type=str, default='sjtu',
                        choices=['ls', 'sjtu', 'wpc'], help='dataset with mos')
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')

    # Optimisation.
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--epoch', type=int, default=30, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--lr_mi', type=float, default=1e-4,
                        help='learning rate of the MI estimator optimizer')
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters.
    parser.add_argument('--lamda', default=0.01, type=float)
    parser.add_argument('--vit', default='model/checkpoints/checkpoint_0.pth', type=str)
    parser.add_argument('--swin', default='model/checkpoints/swin_tiny_patch4_window7_224.pth', type=str)
    parser.add_argument('--num_mi_update', default=5, type=int)

    return parser


def main(args):
    """Train DisPA on the selected dataset and evaluate it after every epoch.

    Creates a timestamped run directory under ``./experiment`` (with
    ``checkpoints/``, ``log/`` and ``output/`` sub-directories) and logs to
    ``log/log.txt``. Each training step first updates the MI estimator
    ``args.num_mi_update`` times, then updates the network on
    ``criterion(pred_mos, mos) + args.lamda * estimate_mi``. After each epoch
    the test split is scored with SROCC, PLCC, KRCC and RMSE.

    Args:
        args: Namespace with the attributes defined by ``get_args_parser``
            (``dataset``, ``fold``, ``gpu``, ``batch_size``, ``epoch``,
            ``lr``, ``lr_mi``, ``weight_decay``, ``num_workers``,
            ``accum_iter``, ``lamda``, ``vit``, ``swin``, ``num_mi_update``).

    Raises:
        ValueError: If ``args.dataset`` is not ``'ls'``, ``'sjtu'`` or ``'wpc'``.
    """
    # Set the device mask before seeding so it is already in place should
    # seed_torch touch CUDA (its body is not visible from this file).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    seed_torch(123)
    # torch.set_num_threads(64)

    # Run directory layout: ./experiment/<timestamp>/{checkpoints,log,output}
    experiment_dir = Path('./experiment')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = experiment_dir / datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir / 'checkpoints'
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir / 'log'
    log_dir.mkdir(exist_ok=True)
    output_dir = file_dir / 'output'
    output_dir.mkdir(exist_ok=True)

    # LOG
    logger = logging.getLogger('DisPA_train')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir / 'log.txt'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('Parameters...')
    logger.info(args)
    print('Parameters...\n', args)

    # DATA LOADING
    logger.info('Loading training data...')
    print('Loading training data...')
    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # which would let an unknown name fall through to an arbitrary dataset.
    dataset_classes = {
        'ls': LSDataset_finetune,
        'sjtu': SJTUDataset,
        'wpc': WPCDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError(
            "unknown dataset {!r}; expected one of {}".format(
                args.dataset, sorted(dataset_classes)))
    dataset_cls = dataset_classes[args.dataset]
    trainDataset = dataset_cls(mode='train', fold=args.fold)
    testDataset = dataset_cls(mode='test', fold=args.fold)

    trainDataloader = DataLoader(trainDataset, batch_size=args.batch_size, shuffle=True,
                                 num_workers=args.num_workers, pin_memory=True)
    testDataloader = DataLoader(testDataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers, pin_memory=True)

    # MODEL
    model = DisPA(swin_checkpoint_path=args.swin, vit_checkpoint_path=args.vit)
    # model = torch.nn.DataParallel(model,device_ids=list(range(torch.cuda.device_count())))
    model = model.cuda()

    # NOTE(review): accum_iter only rescales the learning rate here
    # (eff_lr == lr / accum_iter); the loop below steps the optimizer on every
    # batch and never accumulates gradients. Confirm this is intended.
    eff_batch_size = args.batch_size * args.accum_iter
    eff_lr = args.lr * args.batch_size / eff_batch_size

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=eff_lr, betas=(0.9, 0.95))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    criterion = L2RankLoss()
    # The MI estimator has its own optimizer and learning rate.
    optimizer_mi = torch.optim.Adam(model.miestimator.parameters(), lr=args.lr_mi)

    # TRAIN
    print('Start training...')
    logger.info('Start training...')
    for epoch in range(args.epoch):
        model.train()
        num_iter = len(trainDataloader)
        # Running per-epoch means, accumulated as value / num_iter per batch.
        average_loss, average_mi, average_mi_loss = 0, 0, 0
        for imgs, frags, mos in tqdm(trainDataloader, total=num_iter, leave=False):

            imgs, frags, mos = imgs.cuda(), frags.cuda(), mos.cuda()

            ## updating MI estimator weights
            for _ in range(args.num_mi_update):
                mi_loss = model(imgs, frags, opt='optimizing_estimators')
                optimizer_mi.zero_grad()
                mi_loss.backward()
                optimizer_mi.step()

                # Average over the actual number of estimator updates (this
                # was a hard-coded 5, wrong for any other --num_mi_update).
                average_mi_loss += (mi_loss.mean().item() / args.num_mi_update) / num_iter

            ## updating the network with the MI-regularised quality loss
            estimate_mi, pred_mos = model(imgs, frags, opt='optimizing_networks')
            loss = criterion(pred_mos, mos) + args.lamda * estimate_mi
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            average_loss += (loss.mean().item() / num_iter)
            average_mi += (estimate_mi.mean().item()) / num_iter

        scheduler.step()

        # Checkpoint every 5th epoch and on each of the last 5 epochs.
        if epoch % 5 == 0 or epoch + 5 >= args.epoch:
            misc.save_model(args=args, checkpoint_dir=checkpoints_dir, model=model,
                            optimizer=optimizer, epoch=epoch)

        # The lr is read after scheduler.step(), i.e. already decayed.
        train_msg = 'Epoch:{} lr:{:.6f} train loss:{:.6f} MI:{:.6f} MI loss:{:.6f}'.format(
            epoch, optimizer.param_groups[0]["lr"], average_loss, average_mi, average_mi_loss)
        logger.info(train_msg)
        print(train_msg)

        # TEST
        with torch.no_grad():
            model.eval()
            average_loss, average_mi = 0, 0
            num_iter = len(testDataloader)
            pred_mos_list, mos_list = [], []
            for imgs, frags, mos in tqdm(testDataloader, total=num_iter, leave=False):
                imgs, frags, mos = imgs.cuda(), frags.cuda(), mos.cuda()
                # The names are unused; the unpacking doubles as a check that
                # imgs is 5-D (it raises otherwise).
                B, num_view, C, H, W = imgs.shape
                estimate_mi, pred_mos = model(imgs, frags, opt='inference')

                loss = criterion(pred_mos, mos) + args.lamda * estimate_mi
                average_loss += (loss.mean().item() / num_iter)
                average_mi += (estimate_mi.mean().item()) / num_iter

                pred_mos = pred_mos.detach().cpu().view_as(mos).numpy()
                mos = mos.detach().cpu().numpy()
                pred_mos_list.extend(list(pred_mos))
                mos_list.extend(list(mos))

            # Only the third return value of logistic_4_fitting is used: it
            # replaces the raw predictions before the metrics are computed.
            _, __, pred_mos_list = logistic_4_fitting(pred_mos_list, mos_list)
            pred_mos_series, mos_series = pd.Series(pred_mos_list), pd.Series(mos_list)
            srocc = pred_mos_series.corr(mos_series, method="spearman")
            plcc = pred_mos_series.corr(mos_series, method="pearson")
            krcc = pred_mos_series.corr(mos_series, method="kendall")
            rmse = ((pred_mos_series - mos_series) ** 2).mean() ** .5

            # One message for both sinks, so the console line now carries
            # MI too (its format string previously had no MI placeholder).
            test_msg = ('Epoch:{} test loss: {:.3f} SROCC: {:.3f} PLCC: {:.3f} '
                        'KRCC: {:.3f} RMSE: {:.3f} MI: {:.3f}'
                        .format(epoch, average_loss, srocc, plcc, krcc, rmse, average_mi))
            logger.info(test_msg)
            print(test_msg + '\n')

    print('models in {}'.format(str(file_dir)))
    
if __name__ == '__main__':
    # get_args_parser() is built with add_help=False, so parsing with it
    # directly leaves the script without a working -h/--help. Using it as a
    # parent of a regular parser keeps every argument and restores the
    # help option.
    cli_parser = argparse.ArgumentParser('training', parents=[get_args_parser()])
    args = cli_parser.parse_args()
    main(args)
    print(args)
