
from __future__ import print_function
import argparse
import os,sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from dataset_mix import get_dataloader_fake_paired_list
from losses import CosFace, IDMMD
from torch import distributed

from light_cnn import LightCNN_9Layers_cosface, LightCNN_29Layers_cosface 

# Set up the default torch.distributed process group at import time.
# When WORLD_SIZE and RANK are both present in the environment they are read
# as integers and the NCCL group is initialised without an explicit
# init_method.
try:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    distributed.init_process_group("nccl")
except KeyError:
    # Any KeyError raised in the try body (e.g. a missing environment
    # variable) lands here: fall back to a one-process group (rank 0,
    # world size 1) on a fixed local TCP address.
    world_size = 1
    rank = 0
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12521",
        rank=rank,
        world_size=world_size,
    )

# Command-line interface for the training script.
# Help texts use ``%(default)s`` where possible so the documented default can
# never drift away from the real one.
parser = argparse.ArgumentParser(description='PyTorch Light CNN Training')
parser.add_argument('--arch', '-a', metavar='ARCH', default='LightCNN')
# NOTE: no ``type`` is given, so a value supplied on the command line arrives
# as a (truthy) string; only the default is a real bool.
parser.add_argument('--cuda', '-c', default=True)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=80, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--model', default='', type=str, metavar='Model',
                    help='model type: LightCNN-9, LightCNN-29')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save_path', default='', type=str, metavar='PATH',
                    help='path to save checkpoint (default: none)')
parser.add_argument('--prefix', default='', type=str,
                    help='prefix of log/model')
parser.add_argument('--input_type', default='grey', type=str,
                    help='input type, e.g. grey (default: %(default)s)')

parser.add_argument('--vis_pretrain_weight', default='', type=str)
parser.add_argument('--img_root_F', default='', type=str)
parser.add_argument('--train_list_F', default='', type=str)
# The previous ``default=''`` combined with ``type=int`` made argparse try to
# convert '' to int whenever the flag was omitted, aborting with a confusing
# "invalid int value" error. The flag was therefore already mandatory in
# practice; declare it so explicitly.
parser.add_argument('--num_img_per_id', type=int, required=True,
                    help='number of images per identity (integer, required)')

def main():
    """Parse the command line and run the two-stage training pipeline.

    Stage I optimises only the parameter named ``module.weight`` (at ten times
    the base learning rate) for five epochs via ``pre_train_pair``. Stage II
    optimises every model parameter via ``train`` and writes a checkpoint
    every fifth epoch.

    The parsed arguments are stored in the module-level ``args`` because the
    training helpers read them from there.

    Raises:
        ValueError: if ``--model`` is neither ``LightCNN-9`` nor
            ``LightCNN-29``.
        RuntimeError: if the model exposes no ``module.weight`` parameter to
            train during Stage I.
    """
    global args
    args = parser.parse_args()

    # Reject an unknown architecture before any data is loaded; previously
    # this only printed a message and crashed later on an unbound ``model``.
    model_builders = {
        'LightCNN-9': LightCNN_9Layers_cosface,
        'LightCNN-29': LightCNN_29Layers_cosface,
    }
    if args.model not in model_builders:
        raise ValueError(
            "Error model type: {!r} (expected one of {})".format(
                args.model, sorted(model_builders)))

    seed = 0
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    input_type = args.input_type

    train_loader_fake, num_classes = get_dataloader_fake_paired_list(args)

    model = model_builders[args.model](num_classes=num_classes)

    print("current input type: ", input_type)

    model = torch.nn.DataParallel(model).cuda()

    # The placeholder was missing before, so the path was never shown.
    print("=> loading pretrained vis pre-trained model: {}".format(args.vis_pretrain_weight))
    load_model_vis_lightcnn(model, args.vis_pretrain_weight)

    criterion = nn.CrossEntropyLoss().cuda()
    criterion_idmmd = IDMMD().cuda()
    margin_softmax = CosFace(s=64.0, m=0.3).cuda()

    # Stage I: pre-train only the parameter named "module.weight".
    params_pretrain = []
    for name, value in model.named_parameters():
        if name == "module.weight":
            params_pretrain += [{"params": value, "lr": 10 * args.lr}]

    print("Stage I: trainable params ", len(params_pretrain))
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not params_pretrain:
        raise RuntimeError("no 'module.weight' parameter found for Stage I pre-training")

    # optimizer (the per-group "lr" above overrides the positional args.lr)
    optimizer_pretrain = torch.optim.SGD(params_pretrain, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    pre_train_epochs = 5
    for epoch in range(1, pre_train_epochs + 1):
        pre_train_pair(train_loader_fake, model, criterion, margin_softmax, optimizer_pretrain, epoch)

    # Stage II: fine-tune the full network.
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # The loop always starts at 0 (args.start_epoch is not consulted here) and
    # includes epoch index args.epochs itself.
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs + 1):
        adjust_learning_rate(optimizer, epoch)
        train(train_loader_fake, model, criterion, criterion_idmmd, margin_softmax, optimizer, epoch)

        if epoch % 5 == 0:
            model_name = "".join(args.model.split('-'))
            if args.prefix != '':
                file_name = '{}_{}_e{}.pth.tar'.format(args.prefix, model_name, epoch + 1)
            else:
                file_name = '{}_e{}.pth.tar'.format(model_name, epoch + 1)
            # save_path is used as a plain string prefix (no separator added).
            save_name = args.save_path + file_name

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, save_name)



def pre_train_pair(train_loader, model, criterion, margin_softmax, optimizer, epoch):
    """Run one Stage I epoch using only the margin-softmax cross-entropy loss.

    Each loader item yields (vis_img, nir_img, vis_label, nir_label). The two
    image tensors and the two label tensors are concatenated along dim 0 and
    moved to the GPU. Batches holding fewer than ``2 * args.batch_size``
    samples are skipped. Running precision@1/@5 is logged every
    ``args.print_freq`` iterations.
    """
    prec1_meter = AverageMeter()
    prec5_meter = AverageMeter()

    model.train()
    for step, batch in enumerate(train_loader):
        vis_img, nir_img, vis_label, nir_label = batch

        images = torch.cat((vis_img, nir_img), 0).cuda(non_blocking=True)
        label = torch.cat((vis_label, nir_label), 0).cuda(non_blocking=True)
        batch_size = images.size(0)

        # Incomplete batches are dropped.
        if batch_size < 2*args.batch_size:
            continue

        # forward: the first element of the model output feeds the margin head
        model_out = model(images)
        logits = margin_softmax(model_out[0], label)
        loss = criterion(logits, label)

        # backward + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # running accuracy
        prec1, prec5 = accuracy(logits.data, label.data, topk=(1, 5))
        prec1_meter.update(prec1.item(), batch_size)
        prec5_meter.update(prec5.item(), batch_size)

        # print log
        if i_is_log_step(step):
            pieces = [
                "====> Epoch: [{:0>3d}][{:3d}/{:3d}] | ".format(epoch, step, len(train_loader)),
                "Loss: ce: {:4.3f} | ".format(loss.item()),
                "Prec@1: {:4.2f} ({:4.2f}) Prec@5: {:4.2f} ({:4.2f})".format(
                    prec1_meter.val, prec1_meter.avg, prec5_meter.val, prec5_meter.avg),
            ]
            print("".join(pieces))


def i_is_log_step(step):
    """Return True when iteration ``step`` falls on the ``args.print_freq`` grid."""
    return step % args.print_freq == 0



def train(train_loader, model, criterion, criterion_idmmd, margin_softmax, optimizer, epoch):
    """Run one Stage II epoch with cross-entropy plus the IDMMD loss.

    Each loader item yields (vis_img, nir_img, vis_label, nir_label). The two
    image tensors and the two label tensors are concatenated along dim 0 and
    moved to the GPU. Batches holding fewer than ``2 * args.batch_size``
    samples are skipped. The optimised loss is ``loss_ce + 100 * loss_idmmd``.
    Losses and running precision@1/@5 are logged every ``args.print_freq``
    iterations.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()
    for i, (vis_img, nir_img, vis_label, nir_label) in enumerate(train_loader):

        # ``images`` rather than ``input`` to avoid shadowing the builtin.
        images = torch.cat((vis_img, nir_img), 0).cuda(non_blocking=True)
        label = torch.cat((vis_label, nir_label), 0).cuda(non_blocking=True)
        batch_size = images.size(0)

        # Incomplete batches are dropped.
        if batch_size < 2 * args.batch_size:
            continue

        # forward
        output, fc = model(images)
        output = margin_softmax(output, label)
        loss_ce = criterion(output, label)

        # Split the features back into their vis / nir halves, in the same
        # order they were concatenated above.
        num_vis = vis_img.size(0)
        num_nir = nir_img.size(0)
        fc_vis, fc_nir = torch.split(fc, [num_vis, num_nir], dim=0)

        # Only the vis half of the labels is handed to the IDMMD criterion.
        loss_idmmd = criterion_idmmd(fc_vis, fc_nir, label[:num_vis])
        loss = loss_ce + 100 * loss_idmmd

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.detach(), label, topk=(1, 5))
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # print log (all losses go through .item(), matching pre_train_pair)
        if i % args.print_freq == 0:
            info = "====> Epoch: [{:0>3d}][{:3d}/{:3d}] | ".format(epoch, i, len(train_loader))
            info += "Loss_ce: {:4.3f} | ".format(loss_ce.item())
            info += "loss_idmmd: {:4.3f} | ".format(loss_idmmd.item())
            info += "Loss_all: {:4.3f} | ".format(loss.item())
            info += "Prec@1: {:4.2f} ({:4.2f}) Prec@5: {:4.2f} ({:4.2f})".format(top1.val, top1.avg, top5.val, top5.avg)
            print(info)

def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``filename`` is a path with a directory component, that directory is
    created first if it is missing, so a checkpoint is not lost to a
    ``FileNotFoundError`` after an epoch of training. Non-path targets (e.g.
    file-like objects) are passed to ``torch.save`` untouched.

    Args:
        state: object to serialize (here a dict with ``epoch`` and
            ``state_dict``).
        filename: destination path or file-like object.
    """
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, filename)


class AverageMeter(object):
    """Keep the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and sample count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count

def load_model_vis_lightcnn(model, pretrained):
    """Copy matching weights from a checkpoint file into ``model``.

    The checkpoint must contain a ``"state_dict"`` entry. Only keys that also
    exist in ``model.state_dict()`` are loaded, and keys containing the
    excluded layer name are skipped: ``module.weight`` when the checkpoint
    path contains ``"cos"``, otherwise ``fc2``. Loading is non-strict, so
    model parameters absent from the filtered dict keep their current values.

    Args:
        model: module whose parameters are updated in place.
        pretrained: path of the checkpoint file; the path string itself also
            selects which layer is excluded.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" avoids failures when the checkpoint was saved on a
    # device that does not exist here; load_state_dict copies the values onto
    # whatever device the model parameters live on.
    weights = torch.load(pretrained, map_location="cpu")
    pretrained_dict = weights["state_dict"]
    model_dict = model.state_dict()

    # 'cosface' contains 'cos', so a single substring test covers both.
    excl_layer = 'module.weight' if 'cos' in pretrained else 'fc2'
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and excl_layer not in k}

    print("len of params to be loaded: ", len(pretrained_dict))
    if not pretrained_dict:
        # Typically a key-prefix mismatch; make the silent no-op visible.
        print("Warning: no checkpoint keys matched the model; nothing was loaded")
    model.load_state_dict(pretrained_dict, strict=False)


def adjust_learning_rate(optimizer, epoch):
    """Print the scheduled learning rate and decay the optimizer on step epochs.

    The printed value is ``args.lr`` scaled once per completed block of five
    epochs. On every non-zero epoch that is a multiple of five, each param
    group's current ``lr`` is multiplied by the decay factor.
    """
    decay = 0.457305051927326
    period = 5
    scheduled = args.lr * (decay ** (epoch // period))
    print('lr: {}'.format(scheduled))

    # Nothing to change at epoch 0 or between decay boundaries.
    if epoch == 0 or epoch % period != 0:
        return

    print('Change lr')
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * decay

def accuracy(output, target, topk=(1,)):
    """Return precision@k, as a percentage, for every k in ``topk``.

    ``output`` is indexed as (sample, class) and ``target`` holds one class
    index per sample. The result is a list of 0-dim tensors in the same order
    as ``topk``.
    """
    largest_k = max(topk)
    num_samples = target.size(0)
    percent_per_hit = 100.0 / num_samples

    # Indices of the top-k classes per sample, transposed to (k, sample).
    top_indices = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

    return [
        hits[:k].contiguous().view(-1).float().sum(0).mul_(percent_per_hit)
        for k in topk
    ]

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()