import argparse
import os
import time
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as trn
import numpy as np
from tools.datasets import ImageNet
from tools.model_utils import make_and_restore_model, eval_model
from tqdm import tqdm
from model_eval import load_cached_results, cache_results
from torchvision import models

def main(args):
    """Evaluate one restored checkpoint on a subset of ImageNet-C corruptions.

    The checkpoint is selected by ``args.model_arch``, ``args.adv_train_norm``
    and ``args.adv_train_eps``. For every (family, corruption) pair the top-1
    error is averaged over the severity levels found on disk, stored in the
    results cache under the model key, and finally the mean over all evaluated
    corruptions is printed.

    Args:
        args: Namespace providing ``model_arch`` (str), ``adv_train_norm``,
            ``adv_train_eps`` and ``ngpu`` (0 = CPU). The namespace is mutated:
            ``prefetch`` and ``test_bs`` are set, and ``model_arch`` is replaced
            by a ``wide_resnet50_2`` module instance when its name contains
            ``'wide'``.

    Returns:
        None. Results are printed and written through ``cache_results``.
    """
    # /////////////// Model Setup ///////////////

    train_ds = ImageNet('/tmp')
    root = '/REDACTED/dcr_models/pretrained-robust/'
    # The key must be built from the architecture *name*, i.e. before
    # args.model_arch is possibly swapped for a module instance below.
    mkey = f'{args.model_arch}_{args.adv_train_norm}_eps{args.adv_train_eps}'
    args.model_arch = models.wide_resnet50_2() if 'wide' in args.model_arch else args.model_arch
    net, _ = make_and_restore_model(arch=args.model_arch, dataset=train_ds, resume_path=f'{root}/{mkey}.ckpt')
    args.prefetch = 16
    args.test_bs = 128

    use_cuda = args.ngpu > 0
    # Batches must live on the same device as the model; the original moved
    # them to CUDA unconditionally, which cannot work when ngpu == 0.
    device = torch.device('cuda' if use_cuda else 'cpu')

    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    if use_cuda:
        net.cuda()

    torch.manual_seed(1)
    np.random.seed(1)
    if use_cuda:
        torch.cuda.manual_seed(1)

    net.eval()
    cudnn.benchmark = True  # fire on all cylinders

    print('Model Loaded')

    # /////////////// Further Setup ///////////////

    def show_performance(distortion):
        """Return the mean top-1 error of ``net`` for one corruption.

        Args:
            distortion: ``(distortion_type, distortion_name)`` tuple naming the
                corruption directory.

        Returns:
            Mean of the per-severity error rates over the severities (1-5)
            whose data directory exists.

        Raises:
            FileNotFoundError: if no severity directory exists, so that a NaN
                is never produced and cached for this corruption.
        """
        errs = []
        distortion_type, distortion_name = distortion

        for severity in range(1, 6):
            data_dir = f'/REDACTED/data/imagenet-c/{distortion_type}/{distortion_name}/{severity}'
            if not os.path.exists(data_dir):
                print('Missing ', data_dir)
                continue

            # Images are only center-cropped and converted to tensors; no
            # mean/std normalization is applied in this pipeline.
            # NOTE(review): presumably the restored model normalizes its own
            # inputs - confirm against make_and_restore_model.
            distorted_dataset = dset.ImageFolder(
                root=data_dir,
                transform=trn.Compose([trn.CenterCrop(224), trn.ToTensor()]))

            distorted_dataset_loader = torch.utils.data.DataLoader(
                distorted_dataset, batch_size=args.test_bs, shuffle=False,
                num_workers=args.prefetch, pin_memory=use_cuda)

            correct = 0
            for data, target in distorted_dataset_loader:
                with torch.no_grad():
                    output = net(data.to(device))

                pred = output.data.max(1)[1]
                correct += pred.eq(target.to(device)).sum().item()

            errs.append(1 - 1. * correct / len(distorted_dataset))

        if not errs:
            raise FileNotFoundError(
                f'no severity directories found for {distortion_type}/{distortion_name}')

        print('\n=Average', tuple(errs))
        return np.mean(errs)

    # /////////////// End Further Setup ///////////////

    # /////////////// Display Results ///////////////

    print('\nUsing ImageNet data')

    # Evaluated subset: three corruptions from each of three families.
    blur = ['defocus_blur', 'glass_blur', 'motion_blur']
    noise = ['gaussian_noise', 'shot_noise', 'impulse_noise']
    digital = ['contrast', 'elastic_transform', 'pixelate']
    distortions = []
    for dtype, ds in zip(['blur', 'noise', 'digital'], [blur, noise, digital]):
        distortions.extend([(dtype, d) for d in ds])

    results_path = './results/imagenet_c3.pkl'
    results = load_cached_results(results_path)
    if mkey not in results:
        results[mkey] = dict()
    error_rates = []
    print(results[mkey])
    for distortion in tqdm(distortions):
        distortion_key = '{}_{}'.format(*distortion)
        print(distortion_key)
        # The cache is keyed by the (type, name) tuple, not by distortion_key;
        # kept that way so previously written cache files stay readable.
        if distortion in results[mkey]:
            print('already had that ho')
            rate = results[mkey][distortion]
        else:
            try:
                rate = show_performance(distortion)
            except Exception as e:
                # Best effort: report and move on. Skipping here is essential,
                # otherwise an undefined or stale `rate` would be appended.
                print(f'failed with {e} for {distortion_key}')
                continue
            results[mkey][distortion] = rate
            try:
                cache_results(results_path, results)
            except Exception as e:
                # A failed cache write does not invalidate the measured rate.
                print(f'failed with {e} for {distortion_key}')
        error_rates.append(rate)

    # Avoid numpy's empty-mean warning when every corruption failed.
    average = np.mean(error_rates) if error_rates else np.nan
    # In-memory only: this function does not write the cache again after this.
    results[mkey]['average'] = average

    print('mCE (unnormalized by AlexNet errors) (%): {:.2f}\n\n'.format(100 * average))

if __name__ == '__main__':
    # No command-line options are registered on the parser; the attributes
    # that main() reads are assigned directly on the namespace below.
    parser = argparse.ArgumentParser(
        description='Evaluates robustness of various nets on ImageNet',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    args = parser.parse_args()

    arch = 'wide_resnet50_2'
    # Sweep of (norm, epsilons) pairs, evaluated in this order. The l2 list is
    # empty, so only the single linf checkpoint is evaluated.
    sweep = [
        ('l2', []),
        ('linf', [8.0]),
    ]
    for norm, epsilons in sweep:
        for eps in epsilons:
            # main() overwrites args.model_arch, so every field is reset per run.
            args.ngpu = 1
            args.model_arch = arch
            args.adv_train_norm = norm
            args.adv_train_eps = eps
            main(args)

    # Summarize every model found in the results cache, not only those
    # evaluated in this invocation.
    results = load_cached_results('./results/imagenet_c3.pkl')
    for mkey in results:
        errs = [results[mkey][k] for k in results[mkey]]
        summary = 'Model: {:20s}, Average Corruption Error: {:.3f}'.format(mkey, np.average(errs))
        print(summary)