import logging
import os
import datetime
import torchvision.models as models
import math
import torch
import yaml
from easydict import EasyDict
import shutil
import glob
import pickle
from robustness import datasets as robust_dsets
from robustness import model_utils as robust_model_utils
import numpy as np

class AverageMeter(object):
    """Running tracker for a metric: latest value plus weighted mean.

    Attributes:
        val: the value most recently passed to ``update``.
        sum: accumulated ``val * n`` over all updates since the last reset.
        count: accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest observation, weighted ``n`` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Recomputed on every update so ``avg`` is always ready to read.
        self.avg = self.sum / self.count


def adjust_learning_rate(initial_lr, optimizer, epoch, n_repeats):
    """Apply a step-decay schedule to every parameter group of ``optimizer``.

    The learning rate is ``initial_lr`` multiplied by 0.1 once per completed
    decay period, where a period lasts ``ceil(30 / n_repeats)`` epochs.
    The optimizer is modified in place; nothing is returned.
    """
    decay_period = int(math.ceil(30./n_repeats))
    completed_periods = epoch // decay_period
    new_lr = initial_lr * (0.1 ** completed_periods)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def fgsm(gradz, step_size):
    """Return the signed-gradient step: ``step_size`` times the elementwise sign of ``gradz``."""
    direction = torch.sign(gradz)
    return step_size * direction



def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor; the top-k classes are taken along dim 1.
        target: label tensor with one entry per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per requested k, holding
        the percentage of samples whose label is among the k best scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # After the transpose, row r holds each sample's (r+1)-th best class.
        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]


def initiate_logger(output_path, evaluate):
    """Configure the root logger to also write into ``output/<output_path>/``.

    Creates the output directory when it is missing, sets up basic INFO-level
    logging, attaches a file handler opened in ``'w'`` mode (``eval.txt``
    when ``evaluate`` is truthy, ``log.txt`` otherwise) and logs a short
    header with the date, output name and ``USER`` environment variable.

    Returns:
        The root logger.
    """
    out_dir = os.path.join('output', output_path)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()

    log_name = 'eval.txt' if evaluate else 'log.txt'
    file_handler = logging.FileHandler(os.path.join('output', output_path, log_name), 'w')
    logger.addHandler(file_handler)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
    header_lines = (
        pad_str(' LOGISTICS '),
        'Experiment Date: {}'.format(timestamp),
        'Output Name: {}'.format(output_path),
        'User: {}'.format(os.getenv('USER')),
    )
    for line in header_lines:
        logger.info(line)
    return logger

def get_model_names():
    """Return the sorted names of lowercase, non-dunder callables exposed by ``models``."""
    names = [
        name
        for name, obj in models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(obj)
    ]
    names.sort()
    return names

def pad_str(msg, total_len=70):
    """Center ``msg`` inside a banner of ``*`` characters.

    Args:
        msg: text to place in the middle of the banner.
        total_len: width of the returned string.

    Returns:
        ``msg`` surrounded by ``*`` so the result is exactly ``total_len``
        characters long. When the padding cannot be split evenly, the extra
        ``*`` goes on the right. If ``msg`` is already at least
        ``total_len`` long it is returned unchanged.
    """
    rem_len = max(total_len - len(msg), 0)
    left = rem_len // 2
    # Give the right side whatever is left so odd remainders are not dropped.
    right = rem_len - left
    return '*' * left + msg + '*' * right

def parse_config_file(args):
    """Load the YAML file at ``args.config`` and overlay the parsed arguments.

    Args:
        args: namespace-like object providing ``config`` (path to a YAML
            file) and ``output_prefix``; every attribute found by
            ``vars(args)`` is copied into the result, overriding a
            same-named top-level key from the file.

    Returns:
        An ``EasyDict`` holding the file contents, the argument values and
        ``output_name`` (set to ``args.output_prefix``).
    """
    with open(args.config) as f:
        # safe_load instead of yaml.load(f): PyYAML >= 6 rejects load() without
        # an explicit Loader, and older versions could construct arbitrary
        # Python objects from tagged nodes. Plain scalar/list/dict configs
        # parse the same; files using python-specific tags will now be refused.
        config = EasyDict(yaml.safe_load(f))

    # Command-line values take precedence over values read from the file.
    for k, v in vars(args).items():
        config[k] = v

    config.output_name = args.output_prefix
    return config


def save_checkpoint(state, is_best, filepath, epoch):
    """Write ``state`` to ``<filepath>/checkpoint_epoch<epoch>.pth.tar``.

    When ``is_best`` is truthy the freshly written file is also copied to
    ``<filepath>/model_best.pth.tar``.
    """
    ckpt_path = os.path.join(filepath, f'checkpoint_epoch{epoch}.pth.tar')
    torch.save(state, ckpt_path)
    if not is_best:
        return
    best_path = os.path.join(filepath, 'model_best.pth.tar')
    shutil.copyfile(ckpt_path, best_path)


def print_precision(key):
    """Print the ``best_prec1`` entry of the best phase-2 checkpoint stored for ``key``."""
    ckpt_path = './results/trained_models/{}_phase2/model_best.pth.tar'.format(key)
    checkpoint = torch.load(ckpt_path)
    print(checkpoint['best_prec1'])


def l2_normalize(tens):
    """Scale every sample of a batch to unit L2 norm.

    Args:
        tens: tensor whose first dimension indexes the samples; it may have
            any number of trailing dimensions.

    Returns:
        A tensor of the same shape as ``tens`` in which each ``tens[i]`` is
        divided by the L2 norm taken over all of its elements. No epsilon is
        added, so an all-zero sample produces NaNs.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    norms = torch.norm(tens.reshape(tens.shape[0], -1), p=2, dim=1)
    # One norm per sample, shaped (batch, 1, ..., 1) to broadcast over the
    # trailing dims whatever the rank of the input is.
    broadcast_shape = [-1] + [1] * (tens.dim() - 1)
    return tens / norms.view(broadcast_shape)

def rel_score(core_acc, spur_acc):
    '''
    Relative core sensitivity of the scalar values core_acc and spur_acc.

    The gap (core_acc - spur_acc) is divided by 2*min(avg, 1-avg), where avg
    is the mean of the two inputs. Returns 0 when that mean is exactly 0 or 1,
    which would otherwise make the denominator zero.
    NOTE(review): the 1-avg term suggests inputs are fractions in [0, 1] - confirm with callers.
    '''
    avg = 0.5*(core_acc+spur_acc)
    if avg == 1 or avg == 0:
        return 0
    gap = core_acc - spur_acc
    max_gap = 2*min(avg, 1-avg)
    return gap / max_gap


def clean_up(root='./results/trained_models/'):
    """Delete the epoch 1-5 checkpoints from every entry matched by ``root + '*'``.

    Only files named ``checkpoint_epoch<N>.pth.tar`` with N in 1..5 that sit
    directly inside a matched directory are removed; each removal is printed.
    ``root`` is used as a plain string prefix, so it needs its trailing slash
    to address the contents of a directory.
    """
    doomed_names = ['checkpoint_epoch{}.pth.tar'.format(n) for n in range(1, 6)]

    for model_dir in glob.glob(root + '*'):
        for path in glob.glob(model_dir + '/*'):
            if path.split('/')[-1] not in doomed_names:
                continue
            print('Removing ', path)
            os.remove(path)

def cnt_params(model):
    """Return the total number of scalar entries across ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())


def combine_cached_results():
    combined_dict = dict()
    template_path = './cached_results/pretrained_models{}.pkl'
    for suffix in ['', 2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17]:
    # for suffix in [1,2,3,4]:
        with open(template_path.format(suffix), 'rb') as f:
            d = pickle.load(f)
        for mtype in d:
            if mtype not in combined_dict:
                combined_dict[mtype] = dict()
            for sigma in d[mtype]:
                combined_dict[mtype][sigma] = d[mtype][sigma]

    for mtype in combined_dict:
        ks = list(combined_dict[mtype].keys())
        # ks.remove('clean_acc')
        print('Model: {}, Keys: {}'.format(mtype, sorted(ks)))

    with open('./cached_results/pretrained_models_combined.pkl', 'wb') as f:
        pickle.dump(combined_dict, f)


def load_robust_resnet(mtype):
    """Restore a pretrained checkpoint for ``mtype`` and return the bare network in eval mode.

    ``mtype`` is tested, in order, for the substrings 'wide', 'mobilenet',
    'shufflenet', 'vgg', 'densenet' and 'resnext50'. On the first hit the
    matching torchvision architecture is instantiated, the checkpoint path is
    ``<mtype>.ckpt`` and ``add_custom_forward`` is True.

    When nothing matches, ``mtype`` must split on '_' into exactly four
    fields ``<ignored>_<arch>_<norm>_<suffix>``. The architecture is then
    passed by name, the first three characters of the suffix are dropped
    (e.g. 'eps0.25' -> '0.25') to build ``<arch>_<norm>_eps<eps>.ckpt``, and
    ``add_custom_forward`` is False.

    The model is built with ``robust_model_utils.make_and_restore_model``;
    the returned wrapper's ``.model`` attribute is what this function returns.
    """
    # Factories are lambdas so only the selected architecture is constructed.
    arch_builders = (
        ('wide', lambda: models.wide_resnet50_2() if '50' in mtype else models.wide_resnet101_2()),
        ('mobilenet', lambda: models.mobilenet_v2()),
        ('shufflenet', lambda: models.shufflenet_v2_x1_0()),
        ('vgg', lambda: models.vgg16_bn()),
        ('densenet', lambda: models.densenet161()),
        ('resnext50', lambda: models.resnext50_32x4d()),
    )

    for keyword, build in arch_builders:
        if keyword in mtype:
            arch = build()
            resume_root = '/REDACTED/dcr_models/pretrained-robust/{}.ckpt'.format(mtype)
            add_custom_forward = True
            break
    else:
        # No known substring matched: parse the architecture name out of mtype.
        _, arch, norm, suffix = mtype.split('_')
        eps = suffix[3:]  # e.g. 'eps0.25' -> '0.25'
        resume_root = '/REDACTED/dcr_models/pretrained-robust/{}_{}_eps{}.ckpt'.format(arch, norm, eps)
        add_custom_forward = False

    dataset = robust_dsets.ImageNet('/scratch1/shared/datasets/ILSVRC2012/')
    wrapped, _ = robust_model_utils.make_and_restore_model(
        arch=arch,
        dataset=dataset,
        resume_path=resume_root,
        parallel=False,
        add_custom_forward=add_custom_forward,
    )
    model = wrapped.model
    model.eval()
    return model

def from_scratch_stats_by_key(k):
    """Print a LaTeX-style summary row for model key ``k`` over the from-scratch trials.

    Loads the four trial pickles under ``./cached_results/`` and, for every
    trial that contains ``k``, collects the 'core', 'spur' and 'rca' values
    stored under key 0.25, plus the 'core' value stored under key 0.0 (used as
    the first, "clean" column). A trial that has ``k`` but no 0.0 entry is
    reported with a diagnostic print and left out of the clean column only.

    The printed row shows the number of trials that contained ``k`` followed
    by mean and standard deviation of the clean, core, spur and rca values.
    If no values were collected for a column, numpy yields NaN for it.

    Args:
        k: key to look up in each trial's dict.
    """
    trial_paths = [
        './cached_results/from_scratch.pkl',
        './cached_results/from_scratch_trial2.pkl',
        './cached_results/from_scratch_trial3.pkl',
        './cached_results/from_scratch4.pkl',
    ]
    trial_dirs = []
    for path in trial_paths:
        with open(path, 'rb') as f:
            trial_dirs.append(pickle.load(f))

    cleans, cores, spurs, rcas = [], [], [], []
    for i, t in enumerate(trial_dirs):
        if k not in t:
            continue
        core, spur, rca = [t[k][0.25][x] for x in ['core', 'spur', 'rca']]
        if 0.0 not in t[k]:
            # Flag the trial (by index) whose clean entry is missing.
            print(i, k, t[k].keys())
        else:
            cleans.append(t[k][0.0]['core'])
        cores.append(core)
        spurs.append(spur)
        rcas.append(rca)

    columns = [cleans, cores, spurs, rcas]
    avgs = [np.average(x) for x in columns]
    stds = [np.std(x) for x in columns]
    # Report how many trials contributed; len(avgs) would always be 4 (one per column).
    n_trials = len(cores)
    print('Model:{:<20}, trials:{}, {:.2f}\\% $\\pm$ {:.2f} &{:.2f}\\% $\\pm$ {:.2f} & {:.2f}\\% $\\pm$ {:.2f} & {:.2f}\\% $\\pm$ {:.2f}'.format(
        k, n_trials, avgs[0], stds[0], avgs[1], stds[1], avgs[2], stds[2], avgs[3], stds[3]
    ))

def print_corm_results():
    """Print the from-scratch summary row for each of the four fixed configuration keys."""
    for key in ('baseline', 'noise_0.25_half', 'sal_reg_1', 'sal_reg_1_noise_0.25_half'):
        from_scratch_stats_by_key(key)

def remove_outer_axes(f, ax):
    """Add axes spanning ``[0, 0, 1, 1]`` to ``f`` and hide their top and right spines.

    The ``ax`` argument is never read: it is immediately replaced by the axes
    created through ``f.add_axes``. Returns ``(f, new_axes)``.
    """
    full_axes = f.add_axes([0,0,1,1])
    for side in ('top', 'right'):
        full_axes.spines[side].set_visible(False)
    return f, full_axes

def compute_and_save_balanced_accs(fname):
    """Recompute class-balanced core/spur/rca scores for a cached results file.

    Reads per-class counts from ``./cached_results/cnts_per_class.pkl`` and
    results from ``./cached_results/<fname>``. For every model key and every
    second-level key (named ``sigma`` here), the 'core_by_class' and
    'spur_by_class' values are summed over all listed classes and divided by
    the number of classes with a positive count; ``rel_score`` is applied to
    the two means. All three are stored scaled by 100, along with the
    per-class values restricted to classes with a positive count. The result
    is pickled to ``./cached_results_balanced/<fname>``.
    """
    with open('./cached_results/cnts_per_class.pkl', 'rb') as f:
        cnts = pickle.load(f)

    with open('./cached_results/{}'.format(fname), 'rb') as f:
        results = pickle.load(f)

    num_nonzero_classes = sum(cnts[c] > 0 for c in cnts)
    print(results.keys())

    balanced_d = {}
    for model_key in results:
        per_sigma = balanced_d[model_key] = {}
        for sigma in results[model_key]:
            record = results[model_key][sigma]
            core_by_class = record['core_by_class']
            spur_by_class = record['spur_by_class']

            # Sums run over every listed class; only the divisor is limited
            # to classes with a positive count.
            core = sum(core_by_class[c] for c in core_by_class) / num_nonzero_classes
            spur = sum(spur_by_class[c] for c in spur_by_class) / num_nonzero_classes
            rca = rel_score(core, spur)

            per_sigma[sigma] = {
                'core': 100. * core,
                'spur': 100. * spur,
                'rca': 100. * rca,
                'core_acc_by_class': {c: core_by_class[c] for c in core_by_class if cnts[c] > 0},
                'spur_acc_by_class': {c: spur_by_class[c] for c in spur_by_class if cnts[c] > 0},
            }

    with open('./cached_results_balanced/{}'.format(fname), 'wb') as f:
        pickle.dump(balanced_d, f)

# compute_and_save_balanced_accs('pretrained_models_combined.pkl')
# print_corm_results()
