from spur_ftr_discovery import *
from my_utils import *
import matplotlib.pyplot as plt
from tqdm import tqdm
from paths import *

'''
Ok, so we observe 'confusing' spurious features, where a spurious feature for one class is 
an object belonging to another class. Arguably, these are labeling issues. 

So what we can do is use the segmentation masks from core-feature NAMs to better crop the image.
Specifically, we will 
1. Get high spuriosity images
2. Obtain core NAMs
3. Average and binarize them
4. Compute bounding box and make square
5. Crop
'''

# Module-level setup: everything below runs at import time and performs disk I/O.

# Dataset split used both for the cached rankings file name and for the ImageNet loader.
split='train'
# Cached rankings; indexed later as rankings[c]['spurious'] to obtain dataset indices.
rankings = load_cached_results(f"results/new_img_rankings_by_idx_no_relu_{split}.pkl")
# Resize, center-crop to 224, and convert to a tensor.
t = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()])
dset = datasets.ImageNet(root=_IMAGENET_ROOT, split=split, transform=t)
# NOTE(review): the return type with just_the_model=False is defined by load_robust_resnet — confirm
# it is what compute_nams expects as its first argument.
model = load_robust_resnet('robust_resnet50_l2_eps3', just_the_model=False)


# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
# by_class_dict is only referenced in commented-out code below (by_class_dict[c]['core']).
with open ('ftr_types/by_class_dict.pkl', 'rb') as f:
    by_class_dict = pickle.load(f)

# def square_crop(shorter_min, shorter_max, 
def resize_side(start, end, desired_len, limit=224):
    """Return an interval of length ``desired_len`` centred on ``[start, end)``.

    The interval is shifted, if necessary, so that it lies inside
    ``[0, limit]``. If ``desired_len`` is at least ``limit`` the whole axis
    ``(0, limit)`` is returned, since no interval of that length can fit.

    Args:
        start: Current interval start (inclusive).
        end: Current interval end (exclusive).
        desired_len: Length the returned interval should have.
        limit: Size of the axis the interval must stay within.

    Returns:
        ``(new_start, new_end)`` with ``0 <= new_start <= new_end <= limit``.
    """
    desired_len = int(desired_len)
    if desired_len >= limit:
        # The requested span cannot fit; the full axis is the best available.
        return 0, limit

    # Floor of (midpoint - desired_len / 2). The end is derived from the start
    # instead of truncating both ends independently, so the length is exact
    # even when the start is negative (int() truncates toward zero).
    new_start = int((start + end - desired_len) // 2)
    new_end = new_start + desired_len

    # Slide the window back inside the axis; desired_len < limit guarantees
    # that fixing one edge cannot push the other edge out of bounds.
    if new_start < 0:
        new_start, new_end = 0, desired_len
    elif new_end > limit:
        new_start, new_end = limit - desired_len, limit
    return new_start, new_end


def crop_from_nam(img, avg_nam, thresh=0.9):
    """Crop ``img`` to a square box around the thresholded region of ``avg_nam``.

    Pixels where ``1 - avg_nam`` is at least ``thresh`` (and positive) are
    treated as "on". Their bounding box is padded by scaling the minimum
    coordinates by 0.8 and the maximum coordinates by 1.2, clamped to the
    image, and the shorter side is then widened to match the longer one.

    NOTE(review): this selects pixels with *low* ``avg_nam`` values; confirm
    that this matches the convention of the maps passed in.

    Args:
        img: Image indexed as ``img[:, row, col]``.
        avg_nam: Map whose first two axes are assumed to align with the last
            two axes of ``img`` — TODO confirm against the caller.
        thresh: Threshold applied to ``1 - avg_nam``.

    Returns:
        The cropped view of ``img``, or ``img`` itself when no pixel passes
        the threshold or the padded box is degenerate.
    """
    mask = 1 - avg_nam
    # NaN entries compare False on both sides and are therefore never "on".
    on_pixels = np.where((mask >= thresh) & (mask > 0))
    rows, cols = on_pixels[0], on_pixels[1]
    if rows.size == 0:
        return img

    height, width = img.shape[1], img.shape[2]

    # Pad the box, then clamp the far edges to the image so the side lengths
    # below describe what slicing will actually return.
    x_min, y_min = int(0.8 * rows.min()), int(0.8 * cols.min())
    x_max = min(int(1.2 * rows.max()), height)
    y_max = min(int(1.2 * cols.max()), width)

    xlen, ylen = x_max - x_min, y_max - y_min
    if xlen == 0 or ylen == 0:
        return img

    # Grow the shorter side to the length of the longer one.
    if xlen < ylen:
        x_min, x_max = resize_side(x_min, x_max, ylen, height)
    else:
        y_min, y_max = resize_side(y_min, y_max, xlen, width)

    return img[:, x_min:x_max, y_min:y_max]

resize224 = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])
to_pil = transforms.ToPILImage()

def save_refined_imgs(c, save_root=''):
    """Save a figure of core-feature heatmaps above NAM-cropped images.

    Loads the first four images listed in ``rankings[c]['spurious']``,
    computes NAMs for each core feature, averages them, crops every image
    with ``crop_from_nam`` and writes a two-row figure (heatmaps on top,
    crops below) to ``plots/core_cropped_carwheels.jpg``.

    Args:
        c: Key into ``rankings`` selecting the class to visualise.
        save_root: Currently unused; the output path is fixed.
            NOTE(review): earlier code saved to
            ``save_root + "<class name>.jpg"`` — confirm whether that should
            be restored before calling this for several classes.
    """
    imgs = torch.stack([dset[i][0] for i in rankings[c]['spurious'][:4]])

    all_hmaps, all_nams = [], []
    # Feature 1067 is hard-coded; the alternative is by_class_dict[c]['core'].
    for ftr in [1067]:
        nams = compute_nams(model, imgs.cuda(), ftr)
        all_hmaps.append(compute_heatmaps(imgs, nams))
        all_nams.append(nams)

    # Average the NAMs over features and replace NaNs before thresholding.
    avg_nams = torch.nan_to_num(torch.stack(all_nams).mean(0))
    cropped_imgs = torch.stack([
        resize224(crop_from_nam(img, avg_nam))
        for img, avg_nam in zip(imgs, avg_nams)
    ])

    # Only the heatmaps of the last feature processed are displayed.
    shown_hmaps = all_hmaps[-1]

    fig, axs = plt.subplots(2,1)
    try:
        fig.subplots_adjust(hspace=0)
        for ax in axs:
            ax.set_axis_off()
        # make_grid output is channel-first; imshow needs channel-last.
        axs[0].imshow(make_grid(shown_hmaps, nrow=4).numpy().swapaxes(0,1).swapaxes(1,2))
        axs[1].imshow(make_grid(cropped_imgs, nrow=4).numpy().swapaxes(0,1).swapaxes(1,2))
        axs[0].set_title('Core Feature Soft Segmentations (via NAMs)', fontsize=16)
        axs[1].set_title('Core Cropped Images', fontsize=16)
        fig.savefig("plots/core_cropped_carwheels.jpg", bbox_inches='tight', dpi=300)
    finally:
        # pyplot keeps every figure alive until closed; release it so repeated
        # calls (e.g. one per class) do not accumulate open figures.
        plt.close(fig)


def save_high_spuriosity_imgs(num=4):
    """Save a grid of top-ranked spurious images for the three lowest-gap classes.

    Averages the cached spurious gaps (entry ``['gaps'][10]``) over all models
    in ``results/new_spurious_gap.pkl``, picks the three classes with the
    smallest average, and plots the first ``num`` images from
    ``rankings[c]['spurious']`` for each class as one row of a grid. The
    figure is written to ``plots/mislabel_high_spur.jpg``.

    Args:
        num: Number of images shown per class (also the grid's row width).
    """
    gaps_by_model = load_cached_results('results/new_spurious_gap.pkl')

    # Mean gap per class across every model present in the cache.
    avg_gaps = {
        cls: np.average([gaps_by_model[name]['gaps'][10][cls] for name in gaps_by_model])
        for cls in gaps_by_model['dino_resnet50']['gaps'][10]
    }
    gap_classes = list(avg_gaps.keys())
    gap_values = list(avg_gaps.values())
    # Ascending order: the most negative average gaps come first.
    sorted_classes = [gap_classes[pos] for pos in np.argsort(gap_values)]

    imgs, classnames = [], []
    for cls in sorted_classes[:3]:
        for idx in rankings[cls]['spurious'][:num]:
            imgs.append(dset[idx][0])
        classnames.append(imagenet_classes[cls].title())

    # Channel-first grid -> channel-last array for imshow.
    grid = make_grid(imgs, nrow=num).numpy().transpose(1, 2, 0)

    f, ax = plt.subplots(1,1)
    ax.imshow(grid)
    ax.grid(False)
    for side in ['top', 'left', 'right', 'bottom']:
        ax.spines[side].set_visible(False)

    # One y tick per class row, placed at multiples of 224 offset by 112.
    ax.set_xticks([])
    ax.set_yticks(list(range(112, 3 * 224, 224)))
    ax.set_yticklabels(classnames, rotation=90, va="center", fontsize=12)
    ax.set_ylabel('Class Label', fontsize=16)
    ax.set_title('High Spuriosity Images for Classes\n with Negative Spurious Gaps', fontsize=18)

    f.savefig('plots/mislabel_high_spur.jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)


if __name__ == '__main__':
    # Disabled batch run, kept for reference: rank classes by their average spurious gap
    # and save refined images for the 25 lowest-gap and 25 highest-gap classes.
    # spurious_gaps = load_cached_results('results/new_spurious_gap.pkl')
    # avg_gaps = dict({c:np.average([spurious_gaps[m]['gaps'][10][c] for m in spurious_gaps]) for c in spurious_gaps['dino_resnet50']['gaps'][10]})
    # sorted_classes = [list(avg_gaps.keys())[i] for i in np.argsort(list(avg_gaps.values()))]
    # for c in tqdm(sorted_classes[:25]):
    #     # print(imagenet_classes[c])
    #     save_refined_imgs(c, 'refined/confusing/')

    # for c in tqdm(sorted_classes[-25:]):
    #     save_refined_imgs(c, 'refined/clarifying/')

    # save_high_spuriosity_imgs()

    # Active entry point: build the core-cropped figure for class index 479.
    save_refined_imgs(479)