import os
import time
import argparse
import torch
from torch import nn
from torch.backends import cudnn
import dataset as dataset
import numpy as np
import torchvision
import torch.nn.functional as F
from wideresnet import WideResNet
from lenet import LeNet
import logging
import copy
from partial_models.resnet import ResNet
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True`` (any non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it falsy.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Revisiting Consistency Regularization for Deep Partial Label Learning')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--lam', default=1, type=float)
parser.add_argument('--dataset', type=str, choices=['svhn', 'cifar10', 'cifar100', 'fmnist', 'kmnist','cifar100H','cub200'],
                    default='cifar10')
parser.add_argument('--model', type=str,  default='widenet')
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--rate', default=0.4, type=float,help='-1 for feature, 0.x for random')
parser.add_argument('--noise_rate', default=0.0, type=float,help='noise rate')
parser.add_argument('--bs', default=64, type=int)
parser.add_argument('--trial', default='1', type=str)
parser.add_argument('--data-dir', default='./data/', type=str)
parser.add_argument('--augment', default='DPLL', type=str)
# argparse only applies ``type`` to string defaults, so the default must
# already be a str: it is written to os.environ, which rejects ints.
parser.add_argument('--gpu', default='2', type=str)
parser.add_argument('--seed', help='seed', type=int, default=0)
parser.add_argument('--wd', default=1e-4, type=float)

parser.add_argument('--sharpen', default=1, type=float)
parser.add_argument('--piror_start', default=2000, type=float)
parser.add_argument('--piror', default=0, type=float)
parser.add_argument('--piror_start_auto', action='store_true', default=False, help='whether auto select correct_start')
parser.add_argument('--piror_auto', default='case1', type=str, help = 'for case 3')
parser.add_argument('--piror_add', default=0, type=float, help = 'for case 3')
parser.add_argument('--piror_max', default=1, type=float, help = 'for case 3')
parser.add_argument('--c', default=0.8, type=float)
parser.add_argument('--onehot', default=True, type=_str2bool)

# ---------------------------------------------------------------------------
# Module-level run configuration: parse the CLI, pin the visible GPU, seed the
# RNGs and configure logging. All of this executes at import time.
# ---------------------------------------------------------------------------
args = parser.parse_args()
# os.environ only accepts str values; coerce so that a non-string --gpu value
# (for example an integer default) cannot raise TypeError here.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
# Aliases under the attribute names read elsewhere in the project.
args.partial_rate = args.rate
args.batch_size = args.bs
args.workers = 16
args.augment_type = 'case1'
cudnn.benchmark = True
torch.set_printoptions(precision=2, sci_mode=False)

# Seed NumPy and PyTorch (CPU and every CUDA device) from --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

best_prec1 = 0
# Number of classes per dataset; every dataset not listed here has 10 classes.
num_classes = {'cifar100': 100, 'cifar100H': 100, 'cub200': 200}.get(args.dataset, 10)
logging.basicConfig(format='[%(asctime)s] - %(message)s',
                    datefmt='%Y/%m/%d %H:%M:%S',
                    level=logging.DEBUG,
                    handlers=[logging.StreamHandler()])

args.model_name = 'trial_{}_{}_dataset_{}_binomial_{}_lam_{}'.format(args.trial, args.model, args.dataset, args.rate, args.lam)
logging.info(args)


class AverageMeter:
    """Keeps the most recent value and a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, average, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the average."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: score tensor; the top-k search runs along dim 1.
        target: tensor of true class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one scalar tensor per entry of ``topk``: the percentage of
        rows whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed layout of ``pred``, so a slice
        # of more than one row is not contiguous and ``view(-1)`` raises;
        # ``reshape`` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

# NOTE(review): `ans` is not referenced anywhere else in the visible source;
# confirm nothing imports it before removing.
ans = [0]
def DPLL_train(train_loader, model, optimizer, epoch, consistency_criterion, confidence,args):
    """
    Run one train epoch and return the average loss.

    For every batch the model is evaluated on three views of the input, the
    loss ``lam * (three consistency terms) + supervised term`` is optimised,
    and the per-sample label confidence is refreshed via ``confidence_update``.

    Args:
        train_loader: iterable yielding
            ``(x_aug0, x_aug1, x_aug2, y, part_y, index)`` batches.
        model: network being trained; called on each of the three views.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch; drives the ramp-up of ``lam`` and the
            ``args.piror_start`` gating.
        consistency_criterion: callable applied to
            ``(log-probabilities, confidence targets)``.
        confidence: per-sample confidence array, read with ``confidence[index]``
            and updated in place by ``confidence_update``.
        args: run configuration; ``args.piror`` may be overwritten at the end
            of the epoch.

    Returns:
        The sample-weighted average of the batch losses.
    """
    data_time = AverageMeter()
    batch_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    model.train()
    # Diagnostic only: for each weight in piror_set, count how often the
    # arg-max of the re-weighted prediction equals the label ``y``.
    piror_set = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]
    classfy_piror_set_bingo_num = [0] * len(piror_set)
    margin = []
    # dynamic lam: depends only on the epoch, so compute it once per epoch.
    lam = min((epoch / 100) * args.lam, args.lam)

    for i, (x_aug0, x_aug1, x_aug2, y, part_y, index) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # partial label
        part_y = part_y.float().cuda()
        # original samples with pre-processing
        x_aug0 = x_aug0.cuda()
        y_pred_aug0 = model(x_aug0)
        # augmentation1
        x_aug1 = x_aug1.cuda()
        y_pred_aug1 = model(x_aug1)
        # augmentation2
        x_aug2 = x_aug2.cuda()
        y_pred_aug2 = model(x_aug2)

        y_pred_aug0_probas_log = torch.log_softmax(y_pred_aug0, dim=-1)
        y_pred_aug1_probas_log = torch.log_softmax(y_pred_aug1, dim=-1)
        y_pred_aug2_probas_log = torch.log_softmax(y_pred_aug2, dim=-1)

        y_pred_aug0_probas = torch.softmax(y_pred_aug0, dim=-1)
        y_pred_aug1_probas = torch.softmax(y_pred_aug1, dim=-1)
        y_pred_aug2_probas = torch.softmax(y_pred_aug2, dim=-1)

        # consist loss: all three views share the same target, so convert and
        # transfer the confidence slice once instead of three times.
        target_confidence = torch.tensor(confidence[index]).float().cuda()
        consist_loss0 = consistency_criterion(y_pred_aug0_probas_log, target_confidence)
        consist_loss1 = consistency_criterion(y_pred_aug1_probas_log, target_confidence)
        consist_loss2 = consistency_criterion(y_pred_aug2_probas_log, target_confidence)
        # supervised loss: penalise probability mass on labels outside part_y
        # (1.0000001 keeps the log argument strictly positive).
        super_loss = -torch.mean(torch.sum(torch.log(1.0000001 - F.softmax(y_pred_aug0, dim=1)) * (1 - part_y), dim=1))

        # Unified loss
        final_loss = lam * (consist_loss0 + consist_loss1 + consist_loss2) + super_loss

        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()

        # Geometric mean of the three predicted distributions (pre-step values).
        classfy_out =     torch.pow(y_pred_aug0_probas, 1 / (2 + 1)) \
                        * torch.pow(y_pred_aug1_probas, 1 / (2 + 1)) \
                        * torch.pow(y_pred_aug2_probas, 1 / (2 + 1))
        classfy_out = classfy_out.detach()
        plabels = part_y
        dlabels = y.cuda()
        for jj, prior_weight in enumerate(piror_set):
            reweighted = classfy_out * (plabels + prior_weight * (1 - plabels))
            hits = torch.eq(torch.max(reweighted, 1)[1], dlabels).sum().cpu()
            classfy_piror_set_bingo_num[jj] = classfy_piror_set_bingo_num[jj] + hits
        # Per-sample ratio: best score inside part_y over best score outside it.
        margin += ((torch.max(classfy_out*plabels, 1)[0])/(1e-9+torch.max(classfy_out *(1-plabels), 1)[0])).tolist()

        # update confidence
        confidence_update(confidence, y_pred_aug0_probas, y_pred_aug1_probas, y_pred_aug2_probas, part_y, index,args,epoch=epoch)

        losses.update(final_loss.item(), x_aug0.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 50 == 0:
            logging.info('Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'lam ({lam})\t'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, lam=lam))

    # Refresh args.piror for the following epochs.
    if epoch >= args.piror_start:
        if args.noise_rate > 0:
            if args.piror_auto == 'case1':
                # Take the noise_rate quantile of this epoch's margins. Clamp
                # the position so noise_rate >= 1 cannot index past the end,
                # and leave args.piror untouched when no batch was seen.
                if margin:
                    ranked = sorted(margin)
                    cut = min(int(len(ranked) * args.noise_rate), len(ranked) - 1)
                    args.piror = ranked[cut]
            else:
                args.piror = min(args.piror + args.piror_add, args.piror_max)
        else:
            args.piror = 0
        args.piror = min(args.piror, args.piror_max)
    print(classfy_piror_set_bingo_num,args.piror)
    return losses.avg


    
def confidence_update(confidence, y_pred_aug0_probas, y_pred_aug1_probas, y_pred_aug2_probas, part_y, index,args,epoch=0):
    """Refresh, in place, the confidence rows of the samples in this batch.

    The new estimate is the geometric mean of the three predicted
    distributions, masked by ``part_y`` and normalised to sum to one per row.
    Up to and including epoch ``args.piror_start`` it overwrites the stored
    rows. After that epoch:

    * labels outside ``part_y`` get weight ``args.piror`` instead of zero,
    * the estimate is raised to the 10th power (``args.onehot`` true) or to
      ``args.sharpen`` (otherwise) before normalising,
    * the stored rows become ``args.c * old + (1 - args.c) * new``.

    Args:
        confidence: array of per-sample confidences, modified in place at
            ``confidence[index, :]``.
        y_pred_aug0_probas, y_pred_aug1_probas, y_pred_aug2_probas: predicted
            probabilities for the three views, one row per sample.
        part_y: candidate-label mask with the same shape as the predictions.
        index: positions in ``confidence`` of the samples in this batch.
        args: run configuration (``piror_start``, ``piror``, ``onehot``,
            ``sharpen``, ``c``).
        epoch: current epoch.
    """
    past_start = epoch > args.piror_start

    revisedY0 = part_y.clone()
    if past_start:
        revisedY0 = revisedY0 + args.piror * (1 - revisedY0)

    revisedY0 = revisedY0 * torch.pow(y_pred_aug0_probas.detach(), 1 / (2 + 1)) \
                    * torch.pow(y_pred_aug1_probas.detach(), 1 / (2 + 1)) \
                    * torch.pow(y_pred_aug2_probas.detach(), 1 / (2 + 1))

    # Smallest positive normal number of the working dtype; used to keep the
    # divisions below away from 0/0.
    tiny = torch.finfo(revisedY0.dtype).tiny

    if past_start:
        exponent = args.sharpen if not args.onehot else 10
        # Rescale each row by its maximum before sharpening. The normalised
        # result is mathematically unchanged, but small probabilities can no
        # longer underflow to an all-zero row (which used to give NaN).
        row_max = revisedY0.max(dim=1, keepdim=True)[0].clamp_min(tiny)
        revisedY0 = torch.pow(revisedY0 / row_max, exponent)

    # keepdim sum broadcasts per row without needing the number of classes.
    row_sum = revisedY0.sum(dim=1, keepdim=True).clamp_min(tiny)
    new_confidence = (revisedY0 / row_sum).cpu().numpy()

    if past_start:
        confidence[index, :] = args.c * confidence[index, :] + (1 - args.c) * new_confidence
    else:
        confidence[index, :] = new_confidence


def validate(valid_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    Logs the average top-1 precision and returns
    ``(average top-1 precision, average loss)``, both weighted by batch size.
    ``epoch`` is accepted but not used.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    tick = time.time()

    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            n_samples = images.size(0)
            labels = labels.cuda()

            # forward pass and loss
            logits = model(images.cuda())
            batch_loss = criterion(logits, labels).float()
            logits = logits.float()

            # accumulate accuracy and loss
            prec1 = accuracy(logits.data, labels)[0]
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(prec1.item(), n_samples)

            # per-batch wall time
            now = time.time()
            timer.update(now - tick)
            tick = now

    logging.info(' * Prec@1 {top1.avg:.3f}'.format(top1=acc_meter))
    return acc_meter.avg, loss_meter.avg


def DPLL():
    """Build the data loaders, model and optimizer from ``args`` and train.

    Each epoch runs ``DPLL_train``, steps the learning-rate scheduler and then
    evaluates with ``validate``. The best top-1 precision seen is stored in
    the module-level ``best_prec1`` and logged once training ends.

    Raises:
        ValueError: if ``args.dataset`` or ``args.model`` has no branch here.
    """
    global args, best_prec1
    # load data
    if args.dataset == "cifar10":
        num_classes = 10
        train_loader, test = dataset.cifar10_dataloaders(args.data_dir, args.rate,args=args)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, test = dataset.cifar100_dataloaders(args.data_dir, args.rate,args=args)
    elif args.dataset == 'cifar100H':
        from datasets.cifar100H import load_cifar100H
        args.num_class = 100
        num_classes = 100
        train_loader,train_givenY, test = load_cifar100H(args)
    elif args.dataset == 'cub200':
        from datasets.cub200 import load_cub200
        args.num_class = 200
        num_classes = 200
        train_loader,train_givenY, test = load_cub200(args)
    else:
        # `assert "..."` never fails (a non-empty string is truthy); raise so
        # an unhandled dataset is reported here rather than as a later
        # UnboundLocalError.
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    # load model (a single elif chain, so the final else is reached only for
    # names that matched no branch)
    if args.model == 'widenet':
        model = WideResNet(34, num_classes, widen_factor=10, dropRate=0.0)
    elif args.model == 'widenet-28-2':
        model = WideResNet(28, num_classes, widen_factor=2, dropRate=0.0)
    elif args.model == 'lenet':
        model = LeNet(out_dim=num_classes, in_channel=1, img_sz=28)
    elif args.model == 'resnet18':
        model = ResNet(depth=20,n_outputs=num_classes)
        if args.dataset == 'cub200':
            from torchvision import models
            model = models.resnet18(pretrained=True) # pretrain via imagenet
            # Swap the classification head for one sized to this dataset.
            model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError('Unknown model: {}'.format(args.model))
    model = model.cuda()

    # criterion
    criterion = nn.CrossEntropyLoss().cuda()
    consistency_criterion = nn.KLDivLoss(reduction='batchmean').cuda()
    # optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
    # scheduler
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150,200,250], last_epoch=-1)

    cudnn.benchmark = True
    # init confidence: each row is the candidate-label mask divided by its sum
    if args.dataset == 'cifar100H' or args.dataset == 'cub200':
        train_givenY = torch.FloatTensor(train_givenY)
        tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1])
        confidence = train_givenY.float()/tempY
        confidence = confidence.cpu().numpy()
    else:
        confidence = copy.deepcopy(train_loader.dataset.partial_labels)
        confidence = confidence / confidence.sum(axis=1)[:, None]

    # Train loop
    for epoch in range(0, args.epochs):
        logging.info('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        # training
        DPLL_train(train_loader, model, optimizer, epoch, consistency_criterion, confidence,args)
        # lr_step
        scheduler.step()
        # evaluate on validation set
        valacc, _ = validate(test, model, criterion, epoch)
        best_prec1 = max(best_prec1, valacc)
    logging.info('best Prec@1 {:.3f}'.format(best_prec1))
# Start training only when this file is executed as a script. The argument
# parsing and setup at module level still run on import.
if __name__ == '__main__':
    DPLL()
