import matplotlib.pyplot as plt
from my_utils import *
import numpy as np
import matplotlib.patches as mpatches
from matplotlib import cm
from name_feature import *
from scipy.stats import pearsonr
import timm
from torchvision import models, transforms, utils
from assess_sensitivity import *
import os

def by_corruption_violins(gen_plot=True):
    '''
    Collect relative accuracy drops per model and corruption; optionally plot them.

    For every (cls, ftr) pair in the cached results, the relative drop for a corruption
    is (clean - corrupted) / max(clean, corrupted). Pairs whose clean accuracy is 0 are
    skipped. This shows how ResNet and DeiT differ in robustness to each corruption;
    the other analyses in this file reuse the returned drops to normalize per-model
    sensitivities.

    Args:
        gen_plot: when True, save two violin plots under plots/ -- one comparing
            ResNet vs. DeiT, one comparing core vs. spurious features.

    Returns:
        Nested dict rel_drops_by_cor[mname][cor_key][ftr_type] -> list of relative
        drops, with mname in {'resnet', 'deit'} and ftr_type in {'core', 'spurious'}.
    '''
    results = load_cached_results('results/all_sensitivities.pkl')
    cor_keys = ['blur', 'gray', 'invert', 'patch_shuffle', 'patch_rotate']
    by_class_dict = load_cached_results('../ftr_types/by_class_dict.pkl')

    rel_drops_by_cor = dict()
    for mname in ['resnet', 'deit']:
        rel_drops_by_cor[mname] = {k: {'core': [], 'spurious': []} for k in cor_keys}
        for c in results[mname]:
            for ftr in results[mname][c]:
                ftr_type = 'core' if ftr in by_class_dict[c]['core'] else 'spurious'
                accs = results[mname][c][ftr]
                clean_acc = accs['clean']
                if clean_acc == 0:
                    # Relative drop is not meaningful when the clean accuracy is zero.
                    continue
                for cor_key in cor_keys:
                    rel_drop = (clean_acc - accs[cor_key]) / max(clean_acc, accs[cor_key])
                    rel_drops_by_cor[mname][cor_key][ftr_type].append(rel_drop)

    if gen_plot:
        # 'seaborn-paper' was renamed 'seaborn-v0_8-paper' in matplotlib 3.6 and the old
        # name was later removed; use whichever name this installation provides.
        for style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break
        os.makedirs('plots', exist_ok=True)

        # Two violins per corruption: `by_model` pools feature types and splits by model,
        # `by_ftr_type` pools models and splits by feature type.
        by_model, by_ftr_type, positions, xticks = [], [], [], []
        for i, cor_key in enumerate(cor_keys):
            resnet_drops = rel_drops_by_cor['resnet'][cor_key]
            deit_drops = rel_drops_by_cor['deit'][cor_key]

            by_model.append(resnet_drops['core'] + resnet_drops['spurious'])
            by_model.append(deit_drops['core'] + deit_drops['spurious'])

            by_ftr_type.append(resnet_drops['core'] + deit_drops['core'])
            by_ftr_type.append(resnet_drops['spurious'] + deit_drops['spurious'])

            positions.extend([2*i+0.4, 2*i+1.1])
            xticks.append(2*i+0.75)

        colors = ['deepskyblue', 'coral']
        plot_specs = zip([by_model, by_ftr_type],
                         [['ResNet', 'DeiT'], ['Core', 'Spurious']],
                         ['by_model', 'by_ftr_type'])
        for groups, labels, ext in plot_specs:
            # Drops at or below -0.25 are excluded from the violins (not from the return value).
            groups = [arr[arr > -0.25] for arr in map(np.array, groups)]

            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
            parts = ax.violinplot(groups, positions, widths=0.69, showextrema=False)
            ax.set_xticks(xticks)
            ax.set_xticklabels([ck.replace('_', '\n').title() for ck in cor_keys])
            for i, body in enumerate(parts['bodies']):
                # Violins alternate between the two legend entries.
                body.set_facecolor(colors[i % 2])
                body.set_edgecolor(colors[i % 2])
            ax.legend(handles=[mpatches.Patch(label=l, color=col) for l, col in zip(labels, colors)], loc='lower left')
            ax.set_ylabel('Relative Accuracy Drop due to Feature Corruption')
            fig.tight_layout()
            fig.savefig(f'plots/global_corruption_rel_drops_all_{ext}.jpg', dpi=300)

    return rel_drops_by_cor

def normalized_all_cf_pairs():
    '''
    Plot normalized corruption sensitivities for every (cls, ftr) pair and compare models.

    Goals:
    1) Are there any cls-ftr pairs that are much more sensitive to one corruption than others?
        --> each (c, f, cor) relative drop is normalized by the mean and std of the
            relative drops for that corruption (per model), i.e. a z-score.
    2) Do we get the same trends for the two models after normalization?

    Outputs:
    - plots/all_sensitivities.jpg: one bar plot per class; each feature gets a group of
      bars with two bars per corruption (plain = ResNet, hatched = DeiT).
    - plots/scatter_sensitivities.jpg: ResNet vs. DeiT z-scores with a best-fit line per
      corruption and a dashed line over all corruptions.
    - Pearson correlations and the fitted lines are printed to stdout.
    '''
    cor_keys = ['blur', 'gray', 'invert', 'patch_shuffle', 'patch_rotate']
    mnames = ['resnet', 'deit']
    results = load_cached_results('results/spur_sensitivities.pkl')

    # Per-model, per-corruption mean/std of the relative drop, pooled over core and
    # spurious features; these are the normalization constants for the z-scores.
    rel_drops_by_cor = by_corruption_violins(gen_plot=False)
    means, stds = dict(), dict()
    for mname in mnames:
        pooled = {ck: rel_drops_by_cor[mname][ck]['core'] + rel_drops_by_cor[mname][ck]['spurious']
                  for ck in cor_keys}
        means[mname] = {ck: np.average(vals) for ck, vals in pooled.items()}
        stds[mname] = {ck: np.std(vals) for ck, vals in pooled.items()}

    cs = list(results['resnet'].keys())
    # Keep the original 5x3 layout, growing only if there are more than 15 classes.
    n_cols = 3
    n_rows = max(5, -(-len(cs) // n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
    f_scat, ax_scat = plt.subplots(1, 1, figsize=(4, 4))

    # plt.get_cmap is used because matplotlib.cm.get_cmap is deprecated/removed in recent releases.
    cmap = plt.get_cmap('tab20')
    colors = [cmap(i/10) for i in range(len(cor_keys))]
    handles = [mpatches.Patch(label=cor_key.replace('_', ' ').title(), color=col)
               for col, cor_key in zip(colors, cor_keys)]

    z_scores = {mname: {ck: [] for ck in cor_keys} for mname in mnames}

    for i, c in enumerate(cs):
        ax = axs[i // n_cols, i % n_cols]
        xticks, xticklabels = [], []

        for j, ftr in enumerate(results['resnet'][c]):
            # One tick per feature group (not one per corruption).
            xticks.append(2*j+0.75)
            xticklabels.append(f'Ftr {ftr}\n'+choose_mode(ftr, c, num_names=2).replace(',', '\n'))

            for k, cor_key in enumerate(cor_keys):
                pair_z = []
                for mname in mnames:
                    accs = results[mname][c][ftr]
                    denom = max(accs['clean'], accs[cor_key])
                    if denom == 0:
                        break
                    rel_drop = (accs['clean'] - accs[cor_key]) / denom
                    pair_z.append((rel_drop - means[mname][cor_key]) / stds[mname][cor_key])
                if len(pair_z) != len(mnames):
                    # Relative drop is undefined (0/0) for at least one model; skip this bar pair.
                    continue
                resnet_z_score, deit_z_score = pair_z

                ax.bar(2*j+0.25*k+0.05, resnet_z_score, color=colors[k], width=0.1)
                ax.bar(2*j+0.25*k+0.15, deit_z_score, color=colors[k], width=0.1, hatch='*')

                ax_scat.scatter(resnet_z_score, deit_z_score, color=colors[k], alpha=0.5)
                z_scores['resnet'][cor_key].append(resnet_z_score)
                z_scores['deit'][cor_key].append(deit_z_score)

        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, fontsize=7)
        ax.legend(handles=handles, fontsize=6)
        ax.set_title(imagenet_classes[c].title())
        ax.set_ylabel('Normalized Sensitivity (z-score)')

    os.makedirs('plots', exist_ok=True)
    fig.tight_layout()
    fig.savefig('plots/all_sensitivities.jpg', dpi=300)

    ### checking model consistency
    all_resnet_z_scores, all_deit_z_scores = [], []
    for cor_key, col in zip(cor_keys, colors):
        resnet_z, deit_z = z_scores['resnet'][cor_key], z_scores['deit'][cor_key]
        print(f'Correlation of z-scores bw ResNet and DeiT for {cor_key:<15}: {pearsonr(resnet_z, deit_z)[0]:.4f}')
        bestfitline = np.poly1d(np.polyfit(resnet_z, deit_z, 1))
        print(bestfitline)
        fit_xs = np.unique(resnet_z)
        ax_scat.plot(fit_xs, bestfitline(fit_xs), '.-', color=col)
        all_resnet_z_scores.extend(resnet_z)
        all_deit_z_scores.extend(deit_z)
    print(f'Correlation over all corruptions: {pearsonr(all_resnet_z_scores, all_deit_z_scores)[0]:.4f}')
    bestfitline = np.poly1d(np.polyfit(all_resnet_z_scores, all_deit_z_scores, 1))
    fit_xs = np.unique(all_resnet_z_scores)
    ax_scat.plot(fit_xs, bestfitline(fit_xs), '--', color='blue')
    ax_scat.legend(handles=handles, fontsize=7)
    f_scat.tight_layout()
    f_scat.savefig('plots/scatter_sensitivities.jpg', dpi=300)


def sensitivity_profiles(ftr_type='spurious', mname='resnet', normalize=False):
    '''
    Report, per corruption, the cls-ftr pairs with the most extreme relative sensitivity.

    Each (cls, ftr) pair of the requested feature type gets a 'profile': its relative
    accuracy drop, (clean - corrupted) / max(clean, corrupted), under every corruption.
    The relative profile then subtracts the pair's average over all corruptions, so a
    large value means the feature has a specific sensitivity (e.g. only really sensitive
    to color corruptions). For every corruption, the pairs with the highest and lowest
    relative value are printed and example images are saved via view_extreme_examples.

    Pairs are only profiled when BOTH models have non-zero clean accuracy, so the same
    set of pairs is considered whichever model is inspected.

    Args:
        ftr_type: key of by_class_dict[c] selecting the features to profile
            ('core' or 'spurious').
        mname: model whose accuracies form the profile ('resnet' or 'deit').
        normalize: when True, z-score each relative drop with that model's
            per-corruption mean/std from by_corruption_violins before building the
            profile. The default (False) uses raw relative drops.
    '''
    results = load_cached_results('results/all_sensitivities.pkl')
    by_class_dict = load_cached_results('../ftr_types/by_class_dict.pkl')
    cor_keys = ['blur', 'gray', 'invert', 'patch_shuffle', 'patch_rotate']

    cls_ftr_pairs = [(c, ftr) for c in by_class_dict for ftr in by_class_dict[c][ftr_type]]

    means = stds = None
    if normalize:
        rel_drops_by_cor = by_corruption_violins(gen_plot=False)
        pooled = {ck: rel_drops_by_cor[mname][ck]['core'] + rel_drops_by_cor[mname][ck]['spurious']
                  for ck in cor_keys}
        means = {ck: np.average(vals) for ck, vals in pooled.items()}
        stds = {ck: np.std(vals) for ck, vals in pooled.items()}

    relative_ftr_profiles = dict()
    for c, ftr in cls_ftr_pairs:
        if results['resnet'][c][ftr]['clean'] == 0 or results['deit'][c][ftr]['clean'] == 0:
            continue
        accs = results[mname][c][ftr]
        scores = [(accs['clean'] - accs[ck]) / max(accs['clean'], accs[ck]) for ck in cor_keys]
        if normalize:
            scores = [(s - means[ck]) / stds[ck] for s, ck in zip(scores, cor_keys)]
        avg_score = np.average(scores)
        relative_ftr_profiles[(c, ftr)] = {ck: s - avg_score for ck, s in zip(cor_keys, scores)}

    if not relative_ftr_profiles:
        print(f'No {ftr_type} cls-ftr pairs with non-zero clean accuracy; nothing to report.')
        return

    ptype, profiles = 'relative', relative_ftr_profiles
    for cor_key in cor_keys:
        # Ascending sort: first entry is the least sensitive pair, last is the most sensitive.
        ranked = sorted(profiles, key=lambda pair: profiles[pair][cor_key])
        lowest, highest = ranked[0], ranked[-1]
        print(f'Highest {ptype} sensitivity to {cor_key} corruption: Feature {highest[1]} for class {imagenet_classes[highest[0]].title()} ({highest[0]})')
        print(profiles[highest])
        print(f'Lowest {ptype} sensitivity to {cor_key} corruption: Feature {lowest[1]} for class {imagenet_classes[lowest[0]].title()} ({lowest[0]})')
        print(profiles[lowest])
        print()

        # Most sensitive pair: save images the corruption breaks; least sensitive pair:
        # save images the corruption fixes.
        view_extreme_examples(highest[0], highest[1], ftr_type, cor_key, mname, wrong=True)
        view_extreme_examples(lowest[0], lowest[1], ftr_type, cor_key, mname, wrong=False)
    print('-'*30)

def view_extreme_examples(c, f, ftr_type, cor_key, mname, wrong=True):
    '''
    Save example images where corrupting a feature flips the model's prediction.

    Used to make sense of the feature sensitivity analysis. Images and masks for the
    (cls, ftr) pair are loaded, the corruption is applied, and the model is run on both
    the clean and the corrupted images. Each selected example is saved as a two-panel
    figure (clean on top, corrupted below), titled with the predicted class names, under
    plots/extreme_egs2/<ftr_type>/<highest|lowest>_<cor_key>_<mname>/.

    Args:
        c: class index; a prediction counts as correct when it equals c.
        f: feature index for that class.
        ftr_type: feature type, forwarded to get_imgs_and_ftr_masks and used in the save path.
        cor_key: key into corruption_dict selecting the corruption to apply.
        mname: model to inspect, 'resnet' or 'deit'.
        wrong: when True, keep images that are correct when clean and misclassified
            after corruption; when False, keep images that start out misclassified and
            become correct after corruption.

    Raises:
        ValueError: if mname is not 'resnet' or 'deit'.

    Returns None; does nothing further when no image matches the criterion.
    '''
    import torch  # local import: this function is the only place torch is used directly

    if mname not in ('resnet', 'deit'):
        raise ValueError(f"Unknown model name {mname!r}; expected 'resnet' or 'deit'")

    imgs, ftr_masks, obj_masks = get_imgs_and_ftr_masks(c, f, ftr_type)
    cor_imgs = general_mask_corrupter(imgs, ftr_masks, obj_masks, corruption_dict[cor_key])

    if mname == 'resnet':
        model = models.resnet50(pretrained=True).eval().cuda()
    else:
        model = timm.create_model('deit_small_patch16_224', pretrained=True).eval().cuda()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Inference only: skip autograd bookkeeping, and bring predictions back to the CPU so
    # the selection mask lives on the same device as the image tensors it indexes.
    with torch.no_grad():
        clean_preds, cor_preds = [model(normalize(x.cuda())).argmax(1).cpu() for x in [imgs, cor_imgs]]
    clean_cc_idx, cor_cc_idx = clean_preds == c, cor_preds == c

    idx = (clean_cc_idx & ~cor_cc_idx) if wrong else (~clean_cc_idx & cor_cc_idx)
    if not idx.any():
        return
    cool_clean_imgs, cool_cor_imgs, clean_preds, cor_preds = [x[idx] for x in [imgs, cor_imgs, clean_preds, cor_preds]]

    save_dir = f'plots/extreme_egs2/{ftr_type}/{"highest" if wrong else "lowest"}_{cor_key}_{mname}/'
    os.makedirs(save_dir, exist_ok=True)
    # When wrong=True the clean prediction is the correct one (green) and the corrupted
    # one is the mistake (red); the colors swap when wrong=False.
    clean_color, cor_color = ('darkgreen', 'crimson') if wrong else ('crimson', 'darkgreen')
    examples = zip(cool_clean_imgs, cool_cor_imgs, clean_preds, cor_preds)
    for i, (clean_img, cor_img, clean_pred, cor_pred) in enumerate(examples):
        fig, axs = plt.subplots(2, 1, figsize=(3, 6))
        for axi in axs:
            axi.set_axis_off()
        clean_title, cor_title = [imagenet_classes[pred.item()].replace('_', ' ').title()
                                  for pred in [clean_pred, cor_pred]]
        # Move the first axis last for imshow (assumes channel-first image tensors -- TODO confirm).
        axs[0].imshow(clean_img.numpy().transpose(1, 2, 0))
        axs[0].set_title(clean_title, color=clean_color, fontsize=15)
        axs[1].imshow(cor_img.numpy().transpose(1, 2, 0))
        axs[1].set_title(cor_title, color=cor_color, fontsize=15)
        fig.savefig(save_dir+f"{i}_{clean_title.split(' ')[0]}_{cor_title.split(' ')[0]}.jpg", dpi=300, bbox_inches='tight', pad_inches=0.1)
        # One figure is created per example; close it so figures do not pile up in memory.
        plt.close(fig)


    # imgs_to_save = torch.vstack([cool_clean_imgs[:5], cool_cor_imgs[:5]])
    # print(idx.sum())
    # grid = utils.make_grid(imgs_to_save, nrow=min(5, idx.sum())).numpy().swapaxes(0,1).swapaxes(1,2)
    # f, ax = plt.subplots(1,1, figsize=(4,2))
    # ax.imshow(grid)
    # ax.set_axis_off()

    # wrong_preds = ", ".join([imagenet_classes[(cor_preds if wrong else clean_preds)[i].item()] for i in range(min(5, idx.sum()))])
    # ax.set_title(f'{cor_key} corruption {"hurts" if wrong else "helps"} for class {imagenet_classes[c].title()}\nWrong Preds: {wrong_preds}',
    #     fontsize=4)
    # f.tight_layout(); f.savefig(f'plots/extreme_egs/{ftr_type}/{"highest" if wrong else "lowest"}_{cor_key}_{mname}2.jpg', dpi=300)


if __name__ == '__main__':
    # Script entry point. Only the sensitivity-profile analysis is enabled; the
    # commented-out calls are the other analyses in this file, kept for manual toggling.
    # by_corruption_violins()
    # normalized_all_cf_pairs()
    sensitivity_profiles()

    # One-off inspections of specific (class, feature, corruption, model) combinations.
    # view_extreme_examples(379, 1120, 'spurious', 'gray', 'resnet')
    # view_extreme_examples(379, 1120, 'spurious', 'invert', 'resnet', wrong=False)
    # view_extreme_examples(379, 1120, 'spurious', 'invert', 'resnet')
    # view_extreme_examples(981, 1050, 'spurious', 'patch_shuffle', 'resnet')
