import numpy as np
import sys
import os
import pickle
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from models.allconv import AllConvNet
from models.wrn import WideResNet
import time
from skimage.filters import gaussian as gblur
from PIL import Image as PILImage
import random
from sklearn.metrics import det_curve, accuracy_score, roc_auc_score, auc, precision_recall_curve

def set_random_seed(seed):
    """Seed every random number generator this script relies on.

    ``seed`` is handed to PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed``), Python's ``random`` module and NumPy.
    The cuDNN flags are then set to ``deterministic=True`` and
    ``benchmark=False``.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed):
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def compute_fnr(out_scores, in_scores, fpr_cutoff=.05):
    """Return the false-negative rate at the DET point nearest ``fpr_cutoff``.

    ``out_scores`` are labelled 1 (positive) and ``in_scores`` 0 (negative).
    The DET curve is computed on the pooled scores and the operating point
    whose false-positive rate is closest to ``fpr_cutoff`` is selected.
    When even that closest point has a false-positive rate above 0.1, the
    result is pinned to 1.0.
    """
    labels = np.concatenate([np.zeros(len(in_scores)), np.ones(len(out_scores))])
    scores = np.concatenate([in_scores, out_scores])
    fpr, fnr, _ = det_curve(y_true=labels, y_score=scores)

    nearest = np.argmin(np.abs(fpr - fpr_cutoff))

    # The closest available operating point is still far from the cutoff.
    if fpr[nearest] > 0.1:
        return 1.0

    return fnr[nearest]


def compute_auroc(out_scores, in_scores):
    """Return the ROC AUC with ``out_scores`` as positives and ``in_scores`` as negatives."""
    labels = np.concatenate([np.zeros(len(in_scores)), np.ones(len(out_scores))])
    scores = np.concatenate([in_scores, out_scores])
    return roc_auc_score(y_true=labels, y_score=scores)


def compute_aupr(out_scores, in_scores):
    """Return the area under the precision-recall curve.

    ``out_scores`` are labelled 1 (positive) and ``in_scores`` 0 (negative);
    the area is integrated with ``auc`` over recall (x) and precision (y).
    """
    labels = np.concatenate([np.zeros(len(in_scores)), np.ones(len(out_scores))])
    scores = np.concatenate([in_scores, out_scores])
    precision, recall, _ = precision_recall_curve(labels, scores)
    return auc(recall, precision)


# Make the sibling `utils` package importable when this file is not part of a
# package: the parent of this file's directory is appended to sys.path before
# the `utils.*` imports below.
# NOTE(review): when __package__ is not None nothing here is imported, so
# show_performance, svhn, AddGaussianNoise, etc. would be undefined - confirm
# the script is never loaded that way.
if __package__ is None:
    import sys
    from os import path

    sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
    from utils.display_results import show_performance, get_measures, print_measures, print_measures_with_std
    import utils.svhn_loader as svhn
    import utils.lsun_loader as lsun_loader
    from utils.additional_transform import AddSaltPepperNoise, AddGaussianNoise, Addblur

# Command-line interface. Defaults are shown in --help through
# ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(description='Evaluates a CIFAR OOD Detector',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Setup
parser.add_argument('--test_bs', type=int, default=200, help='Test batch size.')
parser.add_argument('--num_to_avg', type=int, default=1, help='Average measures across num_to_avg runs.')
parser.add_argument('--validate', '-v', action='store_true', help='Evaluate performance on validation distributions.')
# 'diff_entropy' is implemented by the scoring code, so it is selectable here too.
parser.add_argument('--scoring_function', '--score', type=str, default='maxlogits',
                    choices=['msp', 'energy', 'entropy', 'maxlogits', 'diff_entropy'],
                    help='Choose the OOD scoring function.')
# The method name selects the dataset, the architecture and the checkpoint sub-directory.
parser.add_argument('--method_name', '-m', type=str, default='cifar10_wrn_baseline', help='Method name.')
parser.add_argument('--name', type=str, default='baseline', help='Run name used in the checkpoint file name.')
# Loading details
parser.add_argument('--layers', default=40, type=int, help='total number of layers')
parser.add_argument('--widen-factor', default=10, type=int, help='widen factor')
parser.add_argument('--droprate', default=0.3, type=float, help='dropout probability')
parser.add_argument('--load', '-l', type=str, default='./snapshots', help='Checkpoint path to resume / test.')
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=4, help='Pre-fetching threads.')
parser.add_argument('--num_classes', default=10, type=int,
                    help='Number of classes (overwritten according to --method_name).')
parser.add_argument('--severity', default=0.5, type=float, help='noise severity')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help='Choose between CIFAR-10, CIFAR-100.')
parser.add_argument('--model', type=str, default='wrn',
                    choices=['allconv', 'wrn'], help='Choose architecture.')
parser.add_argument('--seed', type=int, default=1, help='Random seed for reproducibility')
# Only these five corruption ids have a test transform; reject anything else up front.
parser.add_argument('--noise', type=int, default=1, choices=[0, 1, 2, 3, 4],
                    help='Test corruption: 0 = Gaussian noise, 1 = salt-and-pepper noise, '
                         '2 = "normal" blur, 3 = Gaussian blur, 4 = mean blur.')
args = parser.parse_args()


# mean and standard deviation of channels of CIFAR-10 images
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]


def _with_tensor_tail(leading):
    """Compose ``leading`` transforms followed by ToTensor and Normalize(mean, std)."""
    return trn.Compose(list(leading) + [trn.ToTensor(), trn.Normalize(mean, std)])


train_transform = _with_tensor_tail([trn.RandomHorizontalFlip(), trn.RandomCrop(32, padding=4)])

# Reference pipeline for the uncorrupted test set (Gaussian noise with amplitude 0).
clean_test_transform = _with_tensor_tail([AddGaussianNoise(amplitude=0)])

# Corruption applied to the test images, keyed by --noise. The factories are
# lazy so only the selected corruption object is constructed.
_test_corruptions = {
    0: lambda: AddGaussianNoise(amplitude=args.severity*10),
    1: AddSaltPepperNoise,
    2: lambda: Addblur(0.5,"normal"),
    3: lambda: Addblur(0.5,"Gaussian"),
    4: lambda: Addblur(0.5,"mean"),
}

# Fail here with a clear message instead of a NameError on `test_transform`
# further down when an unsupported id is given.
if args.noise not in _test_corruptions:
    raise ValueError('Unsupported --noise value {!r}; expected one of {}.'.format(
        args.noise, sorted(_test_corruptions)))

test_transform = _with_tensor_tail([_test_corruptions[args.noise]()])


# The in-distribution dataset follows the method name: 'cifar10_*' methods use
# CIFAR-10, every other method name falls back to CIFAR-100.
if 'cifar10_' in args.method_name:
    cifar_cls, cifar_root, num_classes = dset.CIFAR10, '/datasets/cifar10/', 10
else:
    cifar_cls, cifar_root, num_classes = dset.CIFAR100, '/datasets/cifar100/', 100

train_data_in = cifar_cls(cifar_root, train=True, transform=train_transform)
test_data = cifar_cls(cifar_root, train=False, transform=test_transform)
clean_test_data = cifar_cls(cifar_root, train=False, transform=clean_test_transform)
args.num_classes = num_classes

# Both test loaders share the same settings and keep the dataset order.
loader_kwargs = dict(batch_size=args.test_bs, shuffle=False,
                     num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data, **loader_kwargs)
clean_test_loader = torch.utils.data.DataLoader(clean_test_data, **loader_kwargs)

set_random_seed(args.seed)

# Create model
# The architecture is selected from --method_name (not from --model).
if 'allconv' in args.method_name:
    net = AllConvNet(num_classes)
elif 'wrn' in args.method_name:
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
else:
    print('Model not exists.')
    # `exit()` is a convenience injected by the `site` module and would report
    # success (status 0); terminate explicitly with a failure status instead.
    sys.exit(1)

start_epoch = 0

# Restore model
if args.load != '':
    # Checkpoint sub-directory for the training method named in --method_name.
    # Resolved once, before the epoch scan, because it does not depend on the epoch.
    if 'baseline' in args.method_name:
        subdir = 'baseline'
    elif 'oe_tune' in args.method_name:
        subdir = 'oe_tune_new'
    elif 'energy_tune' in args.method_name:
        subdir = 'energy_tune_new/TIN597'
    else:
        raise ValueError(
            "Cannot derive a checkpoint sub-directory from method name {!r}; expected it to "
            "contain 'baseline', 'oe_tune' or 'energy_tune'.".format(args.method_name))

    checkpoint_dir = os.path.join(args.load, subdir)
    checkpoint_prefix = args.dataset + '_' + args.model + str(args.seed) + '_' + args.name + '_epoch_'

    # Scan epoch indices from high to low so the most recent checkpoint is picked.
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(checkpoint_dir, checkpoint_prefix + str(i) + '.pt')
        if os.path.isfile(model_name):
            print(model_name)
            # The network is still on the CPU at this point; map_location lets a
            # checkpoint saved from a GPU load on hosts without that device.
            net.load_state_dict(torch.load(model_name, map_location='cpu'))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break
    else:
        # An explicit exception (unlike `assert`) still fires under `python -O`.
        raise FileNotFoundError(
            'could not resume: no checkpoint matching {}<epoch>.pt in {}'.format(
                checkpoint_prefix, checkpoint_dir))

# Inference only: switch the network to evaluation mode.
net.eval()

# Wrap in DataParallel when several GPUs are requested, then move to CUDA
# unless --ngpu is 0. cudnn.benchmark is left at the value chosen in
# set_random_seed.
if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
if args.ngpu > 0:
    net.cuda()

# Each OOD evaluation is capped at one fifth of the in-distribution test set size.
ood_num_examples = len(test_data) // 5
expected_ap = ood_num_examples / (ood_num_examples + len(test_data))


def concat(x):
    """Concatenate a sequence of arrays along axis 0."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Return the tensor's data as a NumPy array on the CPU."""
    return x.data.cpu().numpy()

def diff_entropy(alphas):
    """Return the differential entropy of a Dirichlet distribution per row.

    Args:
        alphas: 2-D tensor of concentration parameters; the reduction runs
            over dimension 1, so each row describes one distribution.

    Returns:
        1-D tensor with one entropy value per row of ``alphas``.
    """
    alpha0 = alphas.sum(dim=1)
    digamma_gap = torch.digamma(alphas) - torch.digamma(alpha0).unsqueeze(1)
    per_component = torch.lgamma(alphas) - (alphas - 1) * digamma_gap
    return per_component.sum(dim=1) - torch.lgamma(alpha0)


def _batch_scores(output, smax):
    """Return the per-sample score (NumPy array) for one batch.

    ``output`` holds the network logits and ``smax`` their softmax as a NumPy
    array. The formula is selected by ``args.scoring_function``; any value
    other than the four named ones falls back to the negative maximum
    softmax probability ('msp').
    """
    scoring = args.scoring_function
    if scoring == 'entropy':
        return to_np(output.mean(1) - torch.logsumexp(output, dim=1))
    if scoring == 'energy':
        return to_np(-torch.logsumexp(output, dim=1))
    if scoring == 'maxlogits':
        return -np.max(to_np(output).astype(float), axis=1)
    if scoring == 'diff_entropy':
        return to_np(diff_entropy(F.relu(output) + 1))
    return -np.max(smax, axis=1)


def get_ood_scores(loader, in_dist=False):
    """Score the samples of ``loader`` with the global network.

    Args:
        loader: iterable yielding ``(data, target)`` batches.
        in_dist: when False (OOD data) scoring stops once at least
            ``ood_num_examples`` samples were scored and only scores are
            returned; when True the whole loader is scored and targets are
            used to split scores into correctly / wrongly classified samples.

    Returns:
        ``in_dist=False``: array with at most ``ood_num_examples`` scores.
        ``in_dist=True``: tuple ``(scores, right_scores, wrong_scores, margin)``
        where ``margin`` is the summed gap between the two largest softmax
        probabilities divided by ``len(loader.dataset)``.
    """
    _score = []
    _right_score = []
    _wrong_score = []
    margin_avg = 0.0
    num_scored = 0

    # Feed batches on whatever device the network lives on, so that a CPU run
    # (--ngpu 0) does not require CUDA.
    device = next(net.parameters()).device

    with torch.no_grad():
        for data, target in loader:
            # Cap by number of samples rather than number of batches: this
            # still yields data when the batch size exceeds the cap or does
            # not divide it (the final slice trims any surplus).
            if not in_dist and num_scored >= ood_num_examples:
                break

            output = net(data.to(device))
            probs = F.softmax(output, dim=1)
            smax = to_np(probs)

            # Score each batch once and reuse it for the right/wrong split.
            batch_score = _batch_scores(output, smax)
            _score.append(batch_score)
            num_scored += len(batch_score)

            if in_dist:
                preds = np.argmax(smax, axis=1)
                targets = target.numpy().squeeze()
                right_indices = preds == targets
                wrong_indices = np.invert(right_indices)

                _right_score.append(batch_score[right_indices])
                _wrong_score.append(batch_score[wrong_indices])

                # Gap between the largest and second-largest class probability.
                topk_values, _ = torch.topk(probs, k=2, dim=1)
                margin = topk_values[:, 0] - topk_values[:, 1]
                margin_avg += margin.sum().item()

    if in_dist:
        margin = margin_avg/len(loader.dataset)
        print('Margin {:.2f}'.format(margin))
        return concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy(), margin
    else:
        return concat(_score)[:ood_num_examples].copy()


# Score the corrupted test set and report its classification error.
noise_in_score, noise_right_score, noise_wrong_score, noise_margin = get_ood_scores(test_loader, in_dist=True)

num_right, num_wrong = len(noise_right_score), len(noise_wrong_score)
noiseErrorRate = 100 * num_wrong / (num_wrong + num_right)
print('Noise Error Rate {:.2f}'.format(noiseErrorRate))

# Same for the clean test set.
in_score, right_score, wrong_score, margin = get_ood_scores(clean_test_loader, in_dist=True)

num_right, num_wrong = len(right_score), len(wrong_score)
ErrorRate = 100 * num_wrong / (num_wrong + num_right)
print('Error Rate {:.2f}'.format(ErrorRate))

# /////////////// End Detection Prelims ///////////////

print('\nUsing CIFAR-10 as typical data' if num_classes == 10 else '\nUsing CIFAR-100 as typical data')

# /////////////// Error Detection ///////////////
# Scores of misclassified versus correctly classified clean test samples.

print('\n\nError Detection')
IDauroc, IDaupr, IDfpr = show_performance(wrong_score, right_score, method_name=args.method_name)

# /////////////// Noise Error Detection ///////////////
# The same comparison on the corrupted test set.

print('\n\nNoise Error Detection')
noiseIDauroc, noiseIDaupr, noiseIDfpr = show_performance(noise_wrong_score, noise_right_score,
                                                         method_name=args.method_name)

# /////////////// OOD Detection ///////////////
# Per-dataset averages, appended to by get_and_print_results.
auroc_list, aupr_list, fpr_list = [], [], []
noiseauroc_list, noiseaupr_list, noisefpr_list = [], [], []

def get_and_print_results(ood_loader, num_to_avg=args.num_to_avg):
    """Evaluate OOD detection on ``ood_loader`` and record the averaged metrics.

    For each of ``num_to_avg`` runs the OOD scores are compared against both
    the clean (``in_score``) and the corrupted (``noise_in_score``)
    in-distribution scores. The run means are appended to the module-level
    ``*_list`` accumulators and printed; with five or more runs the
    with-std printer is used for both the clean and the noisy metrics.

    Args:
        ood_loader: data loader over the OOD dataset.
        num_to_avg: number of scoring runs to average (the default is bound
            to ``args.num_to_avg`` when the function is defined).
    """
    aurocs, auprs, fprs = [], [], []
    noiseaurocs, noiseauprs, noisefprs = [], [], []
    for _ in range(num_to_avg):
        out_score = get_ood_scores(ood_loader)
        print ("out_score",out_score)
        # Clean in-distribution reference.
        aurocs.append(compute_auroc(out_score,in_score))
        auprs.append(compute_aupr(out_score,in_score))
        fprs.append(compute_fnr(out_score,in_score))
        # Corrupted in-distribution reference.
        noiseaurocs.append(compute_auroc(out_score,noise_in_score))
        noiseauprs.append(compute_aupr(out_score,noise_in_score))
        noisefprs.append(compute_fnr(out_score,noise_in_score))

    auroc = np.mean(aurocs); aupr = np.mean(auprs); fpr = np.mean(fprs)
    auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr)
    noiseauroc = np.mean(noiseaurocs); noiseaupr = np.mean(noiseauprs); noisefpr = np.mean(noisefprs)
    noiseauroc_list.append(noiseauroc); noiseaupr_list.append(noiseaupr); noisefpr_list.append(noisefpr)

    if num_to_avg >= 5:
        print_measures_with_std(aurocs, auprs, fprs, args.method_name)
        # Report the noisy-reference metrics in this branch as well, matching
        # the fewer-runs branch below.
        print_measures_with_std(noiseaurocs, noiseauprs, noisefprs, args.method_name)
    else:
        print_measures(auroc, aupr, fpr, args.method_name)
        print_measures(noiseauroc, noiseaupr, noisefpr, args.method_name)

    print('{} samples in total.'.format(len(ood_loader.dataset)))


# /////////////// OOD benchmarks ///////////////
# Textures, SVHN, Places365, LSUN, LSUN-resize and iSUN are evaluated in that
# order. Each dataset is opened lazily, right before it is evaluated.

def _image_folder(root):
    """Open an image-folder dataset resized to 32 and centre-cropped to 32."""
    return dset.ImageFolder(root=root,
                            transform=trn.Compose([trn.Resize(32), trn.CenterCrop(32),
                                                   trn.ToTensor(), trn.Normalize(mean, std)]))


def _svhn_test():
    """Open the SVHN test split (resized to 32, no centre crop)."""
    return svhn.SVHN(root='/dataset/SVHN', split="test",
                     transform=trn.Compose([trn.Resize(32), trn.ToTensor(), trn.Normalize(mean, std)]),
                     download=False)


ood_benchmarks = [
    ('\n\nTexture Detection', lambda: _image_folder("/dataset/dtd/images")),
    ('\n\nSVHN Detection', _svhn_test),
    ('\n\nPlaces365 Detection', lambda: _image_folder("/dataset/place365")),
    ('\n\nLSUN Detection', lambda: _image_folder("/dataset/LSUN")),
    ('\n\nLSUN-resize Detection', lambda: _image_folder("/dataset/LSUN_resize/")),
    ('\n\niSUN Detection', lambda: _image_folder("/dataset/iSUN")),
]

for banner, open_dataset in ood_benchmarks:
    ood_data = open_dataset()
    ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.test_bs, shuffle=True,
                                             num_workers=args.prefetch, pin_memory=True)
    print(banner)
    get_and_print_results(ood_loader)

# /////////////// Mean Results ///////////////

print('\n\nMean Test Results')
recall_level = 0.95

# Average every metric over the evaluated OOD datasets.
auroc, aupr, fpr = np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list)
noiseauroc, noiseaupr, noisefpr = np.mean(noiseauroc_list), np.mean(noiseaupr_list), np.mean(noisefpr_list)

# Percent versions, used for printing and for the CSV row.
meanauroc, meanaupr, meanfpr = 100 * auroc, 100 * aupr, 100 * fpr
noisemeanauroc, noisemeanaupr, noisemeanfpr = 100 * noiseauroc, 100 * noiseaupr, 100 * noisefpr

# Metrics against the clean in-distribution scores.
print('\t\t\t\t' + args.method_name)
print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), meanfpr))
print('AUROC: \t\t\t{:.2f}'.format(meanauroc))
print('AUPR:  \t\t\t{:.2f}'.format(meanaupr))

# Metrics against the corrupted in-distribution scores.
print('\t\t\t\t' + args.method_name)
print('noiseFPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), noisemeanfpr))
print('noiseAUROC: \t\t\t{:.2f}'.format(noisemeanauroc))
print('noiseAUPR:  \t\t\t{:.2f}'.format(noisemeanaupr))

# Accuracy on the corrupted and on the clean test set, in percent.
OODAcc = 100 - noiseErrorRate
IDAcc = 100 - ErrorRate

# Append one row per run: seed, both accuracies and the clean OOD means.
csv_name = args.method_name + args.scoring_function + 'test_noise_{noise}.csv'
csv_path = os.path.join(args.load, subdir, csv_name).format(noise=args.noise)
with open(csv_path, 'a') as f:
    row = (args.seed, OODAcc, IDAcc, meanauroc, meanaupr, meanfpr)
    f.write('%01d,%0.2f,%0.2f,%0.2f,%0.2f,%0.2f\n' % row)
