from retrieve_imgs import *
import torch
from torchvision import transforms, models
import timm
import numpy as np
import pickle
from my_utils import *
from tqdm import tqdm

### Corruptions
# Every corruption is a callable ``(img, mask) -> corrupted_img``. The three
# below transform the whole image and ignore ``mask``; the masked blending is
# done by the caller (see general_mask_corrupter).
blur_op = transforms.GaussianBlur(kernel_size=19, sigma=5)
gray_op = transforms.Grayscale()  # torchvision default: one output channel
invert_op = transforms.RandomInvert(p=1)  # p=1 -> the image is always inverted


def blur_fn(img, mask):
    """Gaussian-blur the whole image (kernel 19, sigma 5); ``mask`` is unused."""
    return blur_op(img)


def gray_fn(img, mask):
    """Convert the whole image to grayscale; ``mask`` is unused."""
    return gray_op(img)


def invert_fn(img, mask):
    """Invert the colours of the whole image; ``mask`` is unused."""
    return invert_op(img)

def general_mask_corrupter(imgs, ftr_masks, obj_masks, corruption_fn):
    """Corrupt each image only inside its feature mask, leaving object pixels intact.

    For every (image, feature mask, object mask) triple the corruption region is
    ``clip(ftr_mask - obj_mask, 0, 1)``. ``corruption_fn(img, region)`` produces a
    corrupted image, and the result is blended back as
    ``img * (1 - region) + region * corrupted``, where region values above 0.5
    are first snapped to 1 (values <= 0.5 keep their soft weight).

    Args:
        imgs: images, iterated along the first dimension (presumably (N, C, H, W)
            — TODO confirm against callers).
        ftr_masks: one feature mask per image; assumed to broadcast against an
            image, e.g. (N, 1, H, W) — TODO confirm.
        obj_masks: one object mask per image, same layout as ``ftr_masks``.
        corruption_fn: callable ``(img, mask) -> corrupted_img``.

    Returns:
        The blended images stacked into a single tensor.
    """
    def _corrupt_single(img, ftr_mask, obj_mask):
        # New tensor (inputs are never modified): feature region minus the object.
        region = torch.clip(ftr_mask - obj_mask, 0, 1)
        corrupted = corruption_fn(img, region)
        # Thresholded only after the corruption has seen the soft region.
        region.masked_fill_(region > 0.5, 1)
        return img * (1 - region) + region * corrupted

    return torch.stack([
        _corrupt_single(img, ftr_mask, obj_mask)
        for img, ftr_mask, obj_mask in zip(imgs, ftr_masks, obj_masks)
    ])

def patch_operation(img, mask, patch_s=7, rotate=False, shuffle=False, thresh=0.25):
    """Randomly shuffle and/or rotate the image patches that fall inside ``mask``.

    The image is cut into non-overlapping ``patch_s`` x ``patch_s`` patches. A patch
    counts as masked when the mean of the matching mask patch (over all mask
    channels and pixels) exceeds ``thresh``. Only masked patches are affected.
    Shuffling (applied first) permutes the masked patches among themselves.
    Rotation turns each masked patch by a random multiple of 90 degrees.
    Randomness comes from numpy's global RNG.

    Args:
        img: image tensor of shape (C, H, W); any C, H and W are supported.
        mask: mask tensor with the same spatial size as ``img`` (presumably
            (1, H, W) — TODO confirm against callers).
        patch_s: side length of the square patches.
        rotate: rotate each masked patch by a random multiple of 90 degrees.
        shuffle: randomly permute the masked patches among themselves.
        thresh: minimum mean mask value for a patch to count as masked.

    Returns:
        The corrupted image, shape (C, H, W). If H or W is not a multiple of
        ``patch_s``, the bottom/right border not covered by a full patch comes
        back as zeros (Fold only reconstructs pixels covered by a patch).
    """
    num_channels = img.shape[0]
    # Reconstruct at the input's own size (this used to be hard-coded to 224x224).
    out_size = tuple(img.shape[-2:])
    unfolder = torch.nn.Unfold(kernel_size=patch_s, stride=patch_s)
    folder = torch.nn.Fold(out_size, kernel_size=patch_s, stride=patch_s)

    # Shape (1, C * patch_s * patch_s, num_patches): column i is the flattened i-th patch.
    img_patches = unfolder(img.unsqueeze(0))
    mask_patches = unfolder(mask.unsqueeze(0))

    # Vectorised equivalent of testing mask_patches[0, :, i].mean() > thresh per patch.
    is_masked = mask_patches[0].mean(dim=0) > thresh
    masked_idx = torch.nonzero(is_masked, as_tuple=True)[0]
    num_masked = masked_idx.numel()

    if shuffle:
        # Same numpy RNG draw as before: shuffle of arange(num_masked).
        perm = np.arange(num_masked)
        np.random.shuffle(perm)
        perm = torch.as_tensor(perm, dtype=torch.long, device=masked_idx.device)
        shuffled = img_patches.clone()
        # Patch at masked_idx[i] receives the content of patch masked_idx[perm[i]].
        shuffled[:, :, masked_idx] = img_patches[:, :, masked_idx[perm]]
        img_patches = shuffled

    if rotate:
        num_rots = np.random.randint(0, 4, num_masked)
        # Read from the un-rotated patches, write into a separate copy.
        rotated = img_patches.clone()
        for i, num_rot in zip(masked_idx.tolist(), num_rots.tolist()):
            # Unfold lays a patch out channel-major: (C, patch_s, patch_s).
            patch = img_patches[0, :, i].reshape(num_channels, patch_s, patch_s)
            rotated[0, :, i] = torch.rot90(patch, num_rot, [1, 2]).flatten()
        img_patches = rotated

    # Patches do not overlap, so Fold exactly inverts Unfold on covered pixels.
    return folder(img_patches)[0]

def patch_shuffle_fn(img, mask):
    """Shuffle the masked 28x28 patches of ``img`` among themselves."""
    return patch_operation(img, mask, patch_s=28, shuffle=True)


def patch_rotate_fn(img, mask):
    """Rotate each masked 14x14 patch of ``img`` by a random multiple of 90 degrees."""
    return patch_operation(img, mask, patch_s=14, rotate=True)


# Corruption name -> callable(img, mask); iterated in this order by the sweep below.
corruption_dict = {
    'invert': invert_fn,
    'blur': blur_fn,
    'gray': gray_fn,
    'patch_shuffle': patch_shuffle_fn,
    'patch_rotate': patch_rotate_fn,
}

### Use the corruptions to assess model sensitivities
'''
Pick two models (ResNet-50 and DeiT-S) and compare the accuracy (or probability)
drop caused by each corruption. This can be done for every (class, feature) pair
in HardImageNet (and possibly RIVAL20).
'''

def _count_correct(model, normalize, imgs, target):
    """Return how many of ``imgs`` get top-1 prediction ``target`` from ``model``."""
    return (model(normalize(imgs.cuda())).argmax(1) == target).sum().item()


def assess_sensitivities_for_classes(cls_idx=None):
    """Count correct predictions of ResNet-50 and DeiT-S on clean and corrupted images.

    For each class ``c`` and each of its 'core' and 'spurious' features ``f`` (read
    from ``ftr_types/by_class_dict.pkl``), the images returned by
    ``get_imgs_and_ftr_masks(c, f, ftr_type)`` are classified clean and under every
    corruption in ``corruption_dict``. Each corruption is applied by
    ``general_mask_corrupter``. Results are stored as::

        results[model_name][c][f][key] = number of images predicted as class c

    where ``key`` is ``'clean'`` or a corruption name. Results are cached to
    ``results/all_sensitivities.pkl``, and entries already present are skipped, so
    an interrupted run resumes where it stopped.

    Args:
        cls_idx: a class index or an iterable of class indices to process.
            ``None`` (default) processes classes 0..999.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    results_path = 'results/all_sensitivities.pkl'
    results = load_cached_results(results_path)

    with open('ftr_types/by_class_dict.pkl', 'rb') as fh:
        ftrs_by_class = pickle.load(fh)

    # NOTE(review): newer torchvision deprecates ``pretrained=True`` in favour of ``weights=``.
    resnet = models.resnet50(pretrained=True).eval().cuda()
    deit = timm.create_model('deit_small_patch16_224', pretrained=True).eval().cuda()
    model_dict = {'resnet': resnet, 'deit': deit}

    # cls_idx used to be ignored; None keeps the original "all 1000 classes" behaviour.
    if cls_idx is None:
        classes = range(1000)
    elif np.ndim(cls_idx) == 0:
        classes = [int(cls_idx)]
    else:
        classes = [int(c) for c in cls_idx]

    # Pure inference: no_grad stops autograd from keeping activations alive.
    with torch.no_grad():
        for c in tqdm(classes):
            for mname in model_dict:
                results.setdefault(mname, {}).setdefault(c, {})
            for ftr_type in ['core', 'spurious']:
                for f in ftrs_by_class[c][ftr_type]:
                    # Loaded at most once per (class, feature), only if some result is
                    # missing, and shared by both models. The old code loaded data only
                    # when 'clean' was uncached, so on resume corruptions could run on
                    # undefined or stale images.
                    data = None
                    for mname, model in model_dict.items():
                        entry = results[mname][c].setdefault(f, {})
                        missing = [key for key in ('clean', *corruption_dict) if key not in entry]
                        if not missing:
                            continue
                        if data is None:
                            data = get_imgs_and_ftr_masks(c, f, ftr_type)
                        imgs, ftr_masks, obj_masks = data
                        for key in missing:
                            if key == 'clean':
                                batch = imgs
                            else:
                                batch = general_mask_corrupter(imgs, ftr_masks, obj_masks, corruption_dict[key])
                            entry[key] = _count_correct(model, normalize, batch, c)
                        # Checkpoint after each (model, feature) that produced new results.
                        cache_results(results_path, results)

if __name__ == '__main__':
    # Script entry point: run the sensitivity sweep with default arguments.
    assess_sensitivities_for_classes(cls_idx=None)
