from torchvision.transforms import Normalize, Compose, Resize, CenterCrop, ToTensor
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import torch
from my_utils import *

# Root directory handed to torchvision.datasets.ImageNet in eval_spurious_gap.
# Redacted here; set it to the local ImageNet location before running.
_IMAGENET_ROOT = '' #REDACTED

# Per-channel mean/std normalization using the standard ImageNet statistics.
# NOTE(review): assumes the evaluated models were trained with this same
# preprocessing -- confirm.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def standard_acc(model, loader, apply_norm=True, label_mapping=None, return_preds=False,
                 device='cuda'):
    """Compute top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is called as-is: putting it in eval mode and on ``device`` is
    the caller's responsibility.

    Args:
        model: callable mapping an input batch to per-class scores; the
            predicted class is the argmax over dim 1.
        loader: iterable of batches; the first element of each batch is the
            input tensor and the last element is the label tensor.
        apply_norm: if True, apply the module-level ``normalize`` transform to
            each input batch before the forward pass.
        label_mapping: optional indexable (e.g. dict) mapping each original
            integer label to the label space the model predicts in.
        return_preds: if True, also return every prediction.
        device: device inputs and labels are moved to (defaults to ``'cuda'``).

    Returns:
        The accuracy as a float (``nan`` if the loader yields no samples), or
        ``(accuracy, preds)`` with ``preds`` a numpy array of predicted class
        ids when ``return_preds`` is True.
    """
    correct, total = 0, 0
    all_preds = []
    # Evaluation only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for dat in tqdm(loader):
            x, y = dat[0].to(device), dat[-1].to(device)
            if label_mapping is not None:
                # Build the remapped labels directly on the labels' device.
                y = torch.tensor([label_mapping[y_i.item()] for y_i in y],
                                 dtype=torch.long, device=y.device)

            if apply_norm:
                x = normalize(x)
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += x.shape[0]

            if return_preds:
                all_preds.extend(preds.cpu().numpy())

    # Guard against an empty loader instead of raising ZeroDivisionError.
    acc = correct / total if total else float('nan')
    if return_preds:
        return acc, np.array(all_preds)
    return acc

### 'Spurious Gap' metric -- evaluates accuracy on ImageNet val subsets selected by a per-class spurious ranking
def eval_spurious_gap(model, apply_norm=True, results=None):
    """Compute the 'spurious gap' of ``model`` on ImageNet val subsets.

    Only classes whose entry in
    ``results/img_rankings_by_idx_no_relu_val.pkl`` has a ``'spurious'``
    ranking are used. For each such class:

    * the first 25 ranked images form the 'top' subset;
    * the last 25 ranked images, taken in reverse order, form the 'bot' subset.

    The gap is top-1 accuracy on 'top' minus top-1 accuracy on 'bot'.
    NOTE(review): presumably the ranking orders images by how strongly the
    spurious feature is present -- confirm against the script that produced it.

    Args:
        model: callable mapping an image batch to per-class scores. Inputs are
            moved with ``.cuda()``, so the model must accept CUDA tensors. Eval
            mode is the caller's responsibility.
        apply_norm: if True, apply the module-level ``normalize`` transform to
            each batch before the forward pass.
        results: optional dict to fill in and return. If it already contains
            ``'gap'``, that value is returned without re-evaluating. When
            omitted, a fresh empty dict is used for every call.

    Returns:
        ``(gap, results)``. ``results['top']`` and ``results['bot']`` each hold
        ``{'raw': {'preds': ..., 'ys': ...}, 'acc': ...}``, and
        ``results['gap']`` equals ``gap``.
    """
    # The old shared mutable default (results=dict()) persisted across calls,
    # so every call after the first returned the first model's cached gap.
    if results is None:
        results = dict()
    if 'gap' in results:
        return results['gap'], results

    img_idx = load_cached_results('results/img_rankings_by_idx_no_relu_val.pkl')
    t = Compose([Resize(256), CenterCrop(224), ToTensor()])
    dset = datasets.ImageNet(root=_IMAGENET_ROOT, split='val', transform=t)
    # Keep the originals as plain lists so (path, int) sample tuples keep their
    # types; np.array would coerce them into a 2-D array of strings.
    og_targets, og_imgs, og_samples = list(dset.targets), list(dset.imgs), list(dset.samples)

    cls_inds = [c for c in img_idx if 'spurious' in img_idx[c]]

    for key in ['top', 'bot']:
        results[key] = dict()
        # 'top': ranks 0..24; 'bot': ranks -1..-25 (the last 25, reversed).
        split_idx = np.arange(25) if key == 'top' else np.arange(-1, -26, -1)
        # The subset holds 25 consecutive images per class, grouped in
        # cls_inds order. process_spurious_gap_results relies on this grouping.
        idx = np.stack([img_idx[c]['spurious'][split_idx] for c in cls_inds]).flatten()

        # Restrict the (local) dataset to the selected subset.
        dset.targets = [og_targets[i] for i in idx]
        dset.imgs = [og_imgs[i] for i in idx]
        dset.samples = [og_samples[i] for i in idx]
        loader = DataLoader(dset, batch_size=16, num_workers=16)
        preds, ys = [], []
        # Evaluation only: no_grad avoids building an autograd graph per batch.
        with torch.no_grad():
            for x, y in loader:
                if apply_norm:
                    x = normalize(x)
                preds.extend(model(x.cuda()).argmax(1).cpu().numpy())
                ys.extend([int(a) for a in y])
        preds, ys = np.array(preds), np.array(ys)
        results[key]['raw'] = {'preds': preds, 'ys': ys}
        results[key]['acc'] = (preds == ys).mean()

    results['gap'] = results['top']['acc'] - results['bot']['acc']
    return results['gap'], results
    

def process_spurious_gap_results(path='./results/spurious_gap.pkl', group_size=25,
                                 top_ks=(5, 10, 15)):
    """Add per-class accuracies and top/bot gaps to cached spurious-gap results.

    Loads the dict stored at ``path`` and writes it back after updating it.
    The dict is keyed by model, and each value holds ``'top'`` and ``'bot'``
    entries with ``['raw']['preds']`` and ``['raw']['ys']`` arrays. For every
    model, three keys are added:

    * ``'accs'``: ``{'top'|'bot': {class_id: {k: accuracy on the first k
      images of that class}}}``
    * ``'gaps'``: ``{k: {class_id: top accuracy - bot accuracy}}``
    * ``'avg_gaps'``: ``{'avg_gap': {k: mean of gaps[k] over classes}}``

    Args:
        path: pickle path read with ``load_cached_results`` and rewritten with
            ``cache_results``.
        group_size: number of consecutive predictions per class. This must
            match the per-class subset size used when the results were
            produced (25 in ``eval_spurious_gap``).
        top_ks: prefix lengths within each class group to score.
    """
    results = load_cached_results(path)
    for m in results:
        per_class = {}
        for split in ('top', 'bot'):
            raw = results[m][split]['raw']
            preds, ys = raw['preds'], raw['ys']
            per_class[split] = {}
            # Predictions are stored in contiguous per-class groups, so the
            # label at the start of each group identifies its class.
            for start in range(0, ys.shape[0], group_size):
                cls_preds = preds[start:start + group_size]
                cls_y = ys[start]
                per_class[split][cls_y] = {
                    k: (cls_preds[:k] == cls_y).mean() for k in top_ks
                }

        gaps = {
            k: {y: per_class['top'][y][k] - per_class['bot'][y][k]
                for y in per_class['top']}
            for k in top_ks
        }
        results[m]['accs'] = per_class.copy()
        results[m]['gaps'] = gaps
        results[m]['avg_gaps'] = {
            'avg_gap': {k: np.average(list(gaps[k].values())) for k in top_ks}
        }
    cache_results(path, results)


if __name__ == '__main__':
    # Script entry point: recompute per-class accuracies/gaps for the cached
    # spurious-gap results and re-save them to the same pickle.
    process_spurious_gap_results()