from __future__ import print_function
from textwrap import indent
from turtle import color
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
import torch,pickle
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision.transforms as transforms

import os
import argparse
import json
from PIL import Image
import numpy as np


from PIL import Image
import matplotlib.pyplot as plt


from utils import get_pretrained_model, check_dir, prepare_dset, maha, \
    get_maha_distance, get_maha_distance_cov, get_relative_maha_distance, MahaDistNormalizer, ranking_loss
from networks import *
from torch.autograd import Variable
from torch.nn.functional import one_hot, softmax
import torchvision
from metrics.ood_metrics import OOD_METRICS


def setup_seed(seed):
    """Seed every random number generator the script depends on.

    Applies ``seed`` to, in order: the torch CPU generator, the generators
    of all CUDA devices, NumPy's global generator and Python's ``random``.
    cuDNN determinism is intentionally not enabled here.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # torch.backends.cudnn.deterministic = True  (left disabled on purpose)
# Seed all RNGs with a fixed value so repeated evaluations are reproducible.
# This seed is independent of --random_state, which is handed to the dataset
# code instead.
setup_seed(20)

parser = argparse.ArgumentParser(description='Ensemble Training')
# Pretrained (SSL) feature-extractor settings. None of these are read by live
# code in this script; the Mahalanobis scoring path that used some of them is
# commented out.
parser.add_argument('--maha_file', default='./ssl/maha_dict.npy', type=str)
parser.add_argument('--maha_file_m0', default='./ssl/maha_dict.npy', type=str)
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50')
parser.add_argument('--pretrained', default='', type=str,
                    help='path to moco pretrained checkpoint')
parser.add_argument('--pretrained_model', default='vit', type=str, help='SSL feature map type')
parser.add_argument('--comp_dis', action='store_true', default=False)

# Classifier under evaluation; the checkpoint is read for its 'net' and 'acc'
# entries.
parser.add_argument('--model_path', default='', type=str, help='Model trained with in_dataset')
# Only referenced by commented-out code in this script.
parser.add_argument('--args_path', default='', type=str, help='Arguments for training model')

# NOTE: --gpu is currently ignored; the CUDA_VISIBLE_DEVICES assignment that
# would consume it is commented out.
parser.add_argument('--gpu', default='0', type=str)
parser.add_argument('--batch_size', default=128, type=int)
# 'imagenet' selects the ImagenetNoise branch; any other value goes through
# prepare_dset.
parser.add_argument('--dataset', default='cifar100', type=str, help='cifar10/cifar100/imagenet')
# Only used directly by the ImageNet branch; other datasets take the class
# count from the dataset object.
parser.add_argument('--num_classes', default=100, type=int)


# Label (y) noise, input (x) noise and trigger settings. ImagenetNoise receives
# the noise options explicitly; the rest are presumably read by prepare_dset
# (NOTE(review): verify -- the trigger options are not read in visible code).
parser.add_argument('--ynoise_type', default='symmetric', type=str, help='symmetric/pairflip')
parser.add_argument('--ynoise_rate', default=0.0, type=float, help='label noise rate')
parser.add_argument('--xnoise_type', default='blur', type=str, help='gaussian/blur')
parser.add_argument('--xnoise_arg', default=1, type=float)
parser.add_argument('--xnoise_rate', default=0.0, type=float)
parser.add_argument('--trigger_size', type=int, default=3)
parser.add_argument('--trigger_ratio', type=float, default=0.)


parser.add_argument('--random_state', type=int, default=0)
args = parser.parse_args()

print(args)
# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# Hyper Parameter settings
# use_cuda gates the .cuda() calls in evaluate(). best_acc and optim_type are
# not read anywhere in the visible code.
use_cuda = torch.cuda.is_available()
best_acc = 0
batch_size, optim_type = args.batch_size, cf.optim_type

# Keep the construction order below as is: the RNGs were seeded just before,
# and the dataset builders may draw from them when injecting noise
# (NOTE(review): unverified -- check prepare_dset / ImagenetNoise).
if args.dataset != 'imagenet':
    # prepare_dset returns three splits; trainvalset is not used in the
    # visible code.
    trainset, testset, trainvalset = prepare_dset(args)
    num_classes = trainset.nb_classes
else:
    # Per-channel mean/std commonly used for ImageNet-pretrained models.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Training split: random crop/flip augmentation plus the configured input
    # (x) and label (y) noise.
    trainset = ImagenetNoise(
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        xnoise_rate=args.xnoise_rate,
        xnoise_arg=args.xnoise_arg,
        xnoise_type=args.xnoise_type,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
        num_classes=args.num_classes
    )
    num_classes = args.num_classes
    # Test split: deterministic resize + center crop. No noise arguments are
    # passed (NOTE(review): confirm ImagenetNoise defaults mean clean data).
    testset = ImagenetNoise(
        train=False,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        num_classes=args.num_classes
    )

# trainloader is built but never iterated in the visible code; evaluate()
# consumes only testloader, in fixed order (shuffle=False).
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,num_workers=4)
# metricloader = torch.utils.data.DataLoader(trainvalset, batch_size=1000, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


print('loading the checkpoint')
# Both dataset families read the same checkpoint file; they differ only in how
# the network is materialised from it.
# NOTE(review): torch.load unpickles arbitrary objects -- only point
# --model_path at trusted checkpoints.
model_path = args.model_path
checkpoint = torch.load(model_path)
if args.dataset == 'imagenet':
    # ImageNet: the 'net' entry is loaded as a state_dict into a fresh
    # torchvision ResNet-34 with a 1000-way head.
    net = torchvision.models.resnet34(pretrained=False, num_classes=1000).cuda()
    net.load_state_dict(checkpoint['net'])
else:
    # Other datasets: the 'net' entry is itself a module and is used directly.
    net = checkpoint['net'].cuda()
print('saving acc:', checkpoint['acc'])

# Wrap for multi-GPU execution and let cuDNN pick the fastest kernels.
net = torch.nn.DataParallel(net)
cudnn.benchmark = True

# # load parameters of Gaussian distributed
# maha_intermediate_dict = np.load(args.maha_file, allow_pickle='TRUE')
# # m0_maha_dict = np.load(args.maha_file_m0, allow_pickle='TRUE')
# class_cov_invs = maha_intermediate_dict.item()['class_cov_invs']
# class_means = maha_intermediate_dict.item()['class_means']
# cov_invs = maha_intermediate_dict.item()['cov_inv']
# means = maha_intermediate_dict.item()['mean']
# # cov_invs = m0_maha_dict.item()['cov_inv']
# # means = m0_maha_dict.item()['mean']





# up_sample = nn.Upsample(size=(224,224), mode='bilinear')
# file_name = '/cifar100/cifar100_'
# file_name = '/'+str(args.dataset)+'/'+str(args.dataset)+'_'
# def evaluate_maha_mis():
#     maha_dis_list_succ = []
#     maha_dis_list_err = []
#     maha_dis_list = []
#     corr_err_list = []
#     conf_list = []
#     # pretrain_model.eval()
#     net.eval()
#     with torch.no_grad():
#         # for batch_idx, ((inputs, xnoisy), (targets, true_tar)) in enumerate(trainloader):
#         # for batch_idx, (id,(inputs, xnoisy), (targets, true_tar)) in enumerate(trainloader):
#         # for batch_idx, (id,inputs,targets) in enumerate(testloader):
#         for batch_idx, (inputs,targets) in enumerate(testloader):
#         # for batch_idx, (inputs, targets) in enumerate(trainloader):
#             if args.dataset != 'imagenet':
#                 pretrain_inputs = up_sample(inputs)
#             else:
#                 pretrain_inputs = inputs.cuda()
#             if use_cuda:
#                 inputs, targets = inputs.cuda(), targets.cuda() # GPU settings
            
#             # pre_feature = pretrain_model(pretrain_inputs)
#             inputs, targets = Variable(inputs), Variable(targets)
#             outputs = net(inputs)
#             if args.arch.startswith('hug'):
#                 pre_feature = pre_feature.logits.cpu().data.numpy()
#             else:
#                 pre_feature = pre_feature.cpu().data.numpy()
#             # maha_distance = get_maha_distance(pre_feature,class_cov_invs, class_means, targets.cpu().data.numpy())
#             # maha_distance = get_maha_distance_cov(pre_feature,cov_invs, class_means, targets.cpu().data.numpy())
#             maha_distance = get_relative_maha_distance(pre_feature,cov_invs, class_cov_invs, means, class_means, targets.cpu().data.numpy())
#             maha_dis_list.append(maha_distance)

#             if args_load.ensemble_num > 1:
#                 _, predicted = torch.max(F.softmax(torch.stack(outputs),dim=-1).mean(dim=0).data, 1)
#             else:
#                 conf, predicted = torch.max(F.softmax(outputs.data), 1)
#             # print(predicted.cpu().data.numpy(),targets.cpu().data.numpy())
#             correct_index = (predicted.cpu().data.numpy() == targets.cpu().data.numpy())
#             # print(correct_index)
#             corr_err_list.append(correct_index)
#             conf_list.append(conf.cpu().data.numpy())

#     # print(corr_err_list)
#     maha_dis_list = np.concatenate(maha_dis_list,axis=0).squeeze()
#     corr_err_list = np.concatenate(corr_err_list,axis=0).squeeze()
#     conf_list = np.concatenate(conf_list,axis=0).squeeze()
#     print(maha_dis_list.shape)
#     print(corr_err_list.shape)
#     for i in range(len(maha_dis_list)):
#         # print(i)
#         if corr_err_list[i]:
#             maha_dis_list_succ.append(maha_dis_list[i])
#             # print(i)
#         else:
#             # print('111')
#             maha_dis_list_err.append(maha_dis_list[i])

#     # for i in range(len(conf_list)):
#     #     # print(i)
#     #     if corr_err_list[i]:
#     #         maha_dis_list_succ.append(conf_list[i])
#     #         # print(i)
#     #     else:
#     #         # print('111')
#     #         maha_dis_list_err.append(conf_list[i])
    
#     return maha_dis_list_succ, maha_dis_list_err

def evaluate():
    """Collect misclassification-detection scores on the test set.

    Runs the module-level ``net`` in eval mode over ``testloader`` (moving
    batches to the GPU when ``use_cuda`` is set). For every sample it derives
    two uncertainty scores from the softmax output:

    * ``-conf``: the negated maximum softmax probability, and
    * ``H``: the predictive entropy ``-sum_k p_k * log p_k``.

    For both scores, larger means "less confident". The shape of the
    correctness vector is printed. Then, for each score (MSP first, entropy
    second), the mean over misclassified samples and over correctly
    classified samples is printed, in that order.

    Returns:
        tuple(list, list): ``(succ, err)`` -- the ``-conf`` scores of the
        correctly classified and of the misclassified samples, as lists of
        NumPy scalars.
    """
    def _mean_or_nan(values):
        # np.mean of an empty array returns nan *and* emits a RuntimeWarning;
        # return nan quietly when a group is empty (e.g. 100% accuracy).
        return np.mean(values) if values.size else float('nan')

    correct_batches, conf_batches, neg_entropy_batches = [], [], []
    net.eval()
    with torch.no_grad():
        # Test batches are unpacked as (ids, inputs, targets); ids are unused.
        for _, inputs, targets in testloader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            outputs = net(inputs)

            # Pass dim explicitly: softmax without dim is deprecated (it warns
            # and, for 2-D input, picks dim=1). Assumes outputs are
            # (batch, num_classes) logits -- TODO confirm for every arch used.
            probs = F.softmax(outputs, dim=1)
            log_probs = F.log_softmax(outputs, dim=1)
            conf, predicted = torch.max(probs, 1)
            # sum_k p_k * log p_k is the *negative* entropy; negated below.
            neg_entropy = torch.sum(probs * log_probs, dim=1)

            correct_batches.append((predicted == targets).cpu().numpy())
            conf_batches.append(conf.cpu().numpy())
            neg_entropy_batches.append(neg_entropy.cpu().numpy())

    # reshape(-1) rather than squeeze() so that a single-sample test set still
    # yields a 1-D array that can be masked.
    correct = np.concatenate(correct_batches, axis=0).reshape(-1)
    conf_all = np.concatenate(conf_batches, axis=0).reshape(-1)
    entropy_all = -np.concatenate(neg_entropy_batches, axis=0).reshape(-1)
    print(correct.shape)

    # Negate the confidence so that, like entropy, a higher score flags a
    # likely misclassification.
    msp_scores = -conf_all
    succ, err = msp_scores[correct], msp_scores[~correct]
    print(_mean_or_nan(err), _mean_or_nan(succ))
    print(_mean_or_nan(entropy_all[~correct]), _mean_or_nan(entropy_all[correct]))
    return list(succ), list(err)


# evaluate() returns the negated max-softmax confidence of the correctly
# classified samples (first) and of the misclassified samples (second).
in_confidence_score, ood_confidence_score = evaluate()


print("Evaluating Misclassification Detection Performance...")
# np.longdouble is the portable spelling of extended precision: np.float128 is
# only an alias for it on platforms whose long double is 128-bit (e.g. x86-64
# Linux) and does not exist at all on others (e.g. Windows), where the old
# spelling raised AttributeError.
scores = np.concatenate((in_confidence_score, ood_confidence_score), axis=0).astype(np.longdouble)

# Correctly classified samples are labelled 0, misclassified ones 1; a higher
# score should therefore indicate a likely misclassification.
in_labels = np.zeros_like(in_confidence_score)
out_labels = np.ones_like(ood_confidence_score)
domain_labels = np.concatenate((in_labels, out_labels), axis=0)

tpr95_score = OOD_METRICS["tpr95"](domain_labels, scores)
auroc_score = OOD_METRICS["auroc"](domain_labels, scores)
auprIn_score = OOD_METRICS["auprIn"](domain_labels, scores)
auprOut_score = OOD_METRICS["auprOut"](domain_labels, scores)
de_score = OOD_METRICS["detection_err"](domain_labels, scores)

print("{:20}{:13.2f}% ".format("FPR at TPR 95%:", tpr95_score*100))
print("{:20}{:13.2f}% ".format("Detection error:", de_score*100))
print("{:20}{:13.2f}% ".format("AUROC:",auroc_score*100))
print("{:20}{:13.2f}% ".format("AUPR In:",auprIn_score*100))
print("{:20}{:13.2f}% ".format("AUPR Out:",auprOut_score*100))
