from PIL import Image
import glob
import torch
from torchvision import transforms
import pickle

# Resize straight to 224x224 (aspect ratio is not preserved); the following
# CenterCrop(224) therefore leaves the already-224x224 image unchanged.
transform224 = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)
# Resize the shorter side to 256 (aspect ratio preserved), then take the
# central 224x224 crop and convert to a float tensor.
transform256 = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Load the class-index -> wnid lookup that get_imgs_and_ftr_masks indexes as
# idx_to_wnid[c] (presumably a dict or list keyed by class index; confirm).
# The path is resolved relative to the current working directory, not this file.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# Note that the `with` statement leaves a module-level name `f` bound to the closed file.
with open('../ftr_types/idx_to_wnid.pkl', 'rb') as f:
    idx_to_wnid = pickle.load(f)

 
def _load_image(path, transform, mode=None):
    '''Open the image at `path`, optionally convert it to PIL `mode`, and apply `transform`.

    The file is opened in a `with` block so its handle is released once the
    transform has read the pixel data.
    '''
    with Image.open(path) as im:
        if mode is not None:
            im = im.convert(mode)
        return transform(im)


def _to_three_channels(t):
    '''Return `t` unchanged if dim 0 has size 3, otherwise stack three copies of it along dim 0.'''
    return t if t.shape[0] == 3 else torch.vstack([t, t, t])


def get_imgs_and_ftr_masks(c, f, ftr_type):
    '''
    Load the images of class `c` annotated for feature `f`, together with their feature masks.

    When (c,f) is spurious, we also return object masks, so the object can be left
    uncorrupted while the spurious region is perturbed. For core (c,f), just the images
    and core masks suffice; the returned object masks are all zeros.

    Args:
        c: class index; mapped to a wnid via the module-level `idx_to_wnid`.
        f: feature index within class `c`.
        ftr_type: 'spurious' to load real object masks; any other value yields zero object masks.

    Returns:
        (imgs, ftr_masks, obj_masks): three tensors, each stacked along a new leading
        dimension with one entry per mask file found in the feature directory (sorted by path).
        Single-channel entries are replicated to 3 channels.

    Raises:
        FileNotFoundError: if the feature directory contains no files.
    '''
    wnid = idx_to_wnid[c]
    root = f'{_SALIENT_IMAGNET_ROOT}/{wnid}/feature_{f}/'

    # One mask file per annotated image; sort because glob order is filesystem-dependent.
    mask_fs = sorted(glob.glob(root + '*'))
    if not mask_fs:
        raise FileNotFoundError(f'no feature masks found under {root}')

    # The original image lives at _IMAGENET_ROOT/<prefix>/<name>, where <name> is the mask's
    # file name and <prefix> is the part of it before the first '_' (presumably the wnid).
    names = [p.split('/')[-1] for p in mask_fs]
    img_fs = [_IMAGENET_ROOT + f'{name.split("_")[0]}/{name}' for name in names]

    # Convert to RGB so grayscale/CMYK/RGBA files all yield 3-channel tensors.
    imgs = [_load_image(p, transform256, mode='RGB') for p in img_fs]
    # Feature masks come from the Salient ImageNet mask files, not from the ImageNet images.
    ftr_masks = [_load_image(p, transform224) for p in mask_fs]

    if ftr_type == 'spurious':
        if c in hard_imagenet_idx:
            hardimagenet_root = '../../data/hardImageNet/'
            # These masks get the same transform as the images (transform256);
            # presumably they are stored at the original image resolution.
            obj_masks = [
                _load_image(hardimagenet_root + 'train/' + wnid + '_' + name.replace('.jpeg', '.JPEG'), transform256)
                for name in names
            ]
        else:
            consolidated_core_mask_dir = f'../../data/core_masks/class_{c}/'
            obj_masks = [
                _load_image(consolidated_core_mask_dir + name.replace('JPEG', 'jpeg'), transform224)
                for name in names
            ]
    else:
        # for core masks, we do not need to worry about not corrupting the object
        obj_masks = [torch.zeros(imgs[0].shape) for _ in imgs]

    imgs, ftr_masks, obj_masks = [
        torch.stack([_to_three_channels(t) for t in tensors])
        for tensors in (imgs, ftr_masks, obj_masks)
    ]
    return imgs, ftr_masks, obj_masks


