import torch
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from torchattacks import PGDL2
from tqdm import tqdm
import numpy as np
from my_utils import *
from robustness.datasets import ImageNet
from robustness.model_utils import make_and_restore_model
from collections import OrderedDict
from pretrained_model_eval import *
from paths import *

class Trainer(object):
    '''Builds data loaders and a ResNet classifier for `dset`, and trains / saves / restores it.

    Two model modes, selected by `mkey`:
      * 'finetune' in mkey: a ResNet-50 restored from a checkpoint via
        `make_and_restore_model`; the backbone is frozen and only a new linear
        head is optimized.
      * otherwise: a randomly initialized torchvision ResNet-18 whose `fc` layer
        is replaced; all parameters are optimized.
    '''

    def __init__(self, mkey='resnet18', dset='waterbirds', epochs=20, trial=None):
        '''Set up loaders, model, optimizer, scheduler and checkpoint path.

        Args:
            mkey: model key; the substring 'finetune' selects the frozen-backbone mode.
            dset: 'waterbirds', 'celebA', or a name containing 'utkface'
                (e.g. 'utkface_race').
            epochs: number of training epochs (also the scheduler's T_max).
            trial: accepted for caller compatibility; not used by this class.
        '''
        # configure data (also sets self.num_classes, which init_model reads)
        self.dset = dset
        self.init_loaders()
        # configure model
        self.mkey = mkey
        self.finetune = ('finetune' in mkey)
        self.init_model(mkey)
        # configure training + saving pipeline
        self.optimizer = torch.optim.Adam(self.parameters, lr=0.0003,
                                          betas=(0.9, 0.999), weight_decay=0.003)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=epochs)

        self.save_path = f'finetuned/{dset}.pth'
        self.criterion = torch.nn.CrossEntropyLoss()
        self.best_acc = 0
        self.num_epochs = epochs

    def init_model(self, mkey):
        '''Create `self.model`, `self.parameters` (list of trainable params) and `self.attack`.'''
        if self.finetune:
            train_ds = ImageNet('/tmp')
            net, _ = make_and_restore_model(arch='resnet50', dataset=train_ds, resume_path='/cmlscratch/mmoayeri/models/pretrained-robust/resnet50_l2_eps3.ckpt',
                parallel=False, add_custom_forward=False)
            # normalizer + every child of the restored network except its last one
            feat_net = torch.nn.Sequential(OrderedDict([('normalizer', net.normalizer), *list(net.model.named_children())[:-1]]))
            self.model = torch.nn.Sequential()
            self.model.add_module('feat_net', feat_net)
            self.model.add_module('flatten', torch.nn.Flatten())
            in_ftrs = 2048  # hard-coded feature width of the flattened backbone output
            self.model.add_module('classifier', torch.nn.Linear(in_features=in_ftrs, out_features=self.num_classes, bias=True))
            self._gradcam_layer = self.model.feat_net.layer4[-1]

            self.parameters = list(self.model.classifier.parameters())
            # freeze all non-final-layer parameters
            for param in self.model.feat_net.parameters():
                param.requires_grad = False
            self.model = self.model.cuda()
        else:
            self.model = models.resnet18(pretrained=False)
            self.model.fc = torch.nn.Linear(in_features=512, out_features=self.num_classes, bias=True)
            self._gradcam_layer = self.model.layer4[-1]

            self.model = self.model.cuda()
            # materialize as a list so the attribute is still usable after the
            # optimizer has consumed it (a bare generator would be exhausted)
            self.parameters = list(self.model.parameters())

        # built here, but its only use in process_epoch is currently commented out
        self.attack = PGDL2(self.model, eps=3.0, alpha=0.5, steps=10, random_start=True)

    def init_loaders(self, shuffle=True):
        '''Populate `self.loaders` ({'train', 'test'}) and `self.num_classes` for `self.dset`.

        Args:
            shuffle: for waterbirds/celebA selects `get_train_loader` (True) or
                `get_eval_loader` (False) for BOTH splits; for utkface it is
                passed to `DataLoader(shuffle=...)`.

        Raises:
            ValueError: if `self.dset` is not a supported dataset name.
        '''
        # Defined before branching: the utkface branch also needs it (it used to
        # be bound only in the waterbirds/celebA branch).
        # NOTE(review): assumes UTKFace should use the same preprocessing - confirm.
        transform = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()])
        if self.dset == 'waterbirds' or self.dset == 'celebA':
            dataset = get_dataset(self.dset, root_dir=_WATERBIRDS_CELEBA_ROOT)
            train_dset, test_dset = [dataset.get_subset(s, transform=transform) for s in ['train', 'test']]
            if shuffle:
                train_loader, test_loader = [get_train_loader("standard", d, batch_size=32, num_workers=16) for d in [train_dset, test_dset]]
            else:
                train_loader, test_loader = [get_eval_loader("standard", d, batch_size=32, num_workers=16) for d in [train_dset, test_dset]]
            self.num_classes = 2
        elif 'utkface' in self.dset:
            # the prediction target is the suffix of the dataset name, e.g. 'utkface_race'
            target = self.dset.split('_')[-1]
            train_dset, test_dset = [UTKFace(target=target, split=s, transform=transform) for s in ['train', 'test']]
            train_loader, test_loader = [torch.utils.data.DataLoader(d, batch_size=32, num_workers=16, shuffle=shuffle)
                                         for d in [train_dset, test_dset]]
            self.num_classes = 5 if target == 'race' else 2
        else:
            raise ValueError(f'Unsupported dataset: {self.dset!r}')

        self.loaders = {'train': train_loader, 'test': test_loader}

    def save_model(self):
        '''Write the model state dict and `self.best_acc` to `self.save_path`.'''
        import os

        self.model.eval()
        save_dict = dict({'state': self.model.state_dict(),
                          'acc': self.best_acc})
        # torch.save does not create missing parent directories
        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(save_dict, self.save_path)
        print('\nSaved model with accuracy: {:.3f} to {}\n'.format(self.best_acc, self.save_path))

    def restore_model(self):
        '''Load weights from `self.save_path`, switch to eval mode, and set `self.test_acc`.'''
        print('Loading model from {}'.format(self.save_path))
        save_dict = torch.load(self.save_path)
        self.model.load_state_dict(save_dict['state'])
        self.model.eval()
        self.test_acc = save_dict['acc']

    @property
    def gradcam_layer(self):
        '''Last block of `layer4`, in both model modes.

        Exposed as a property: previously an instance attribute of the same name
        shadowed a method that returned itself.
        '''
        return self._gradcam_layer

    @gradcam_layer.setter
    def gradcam_layer(self, layer):
        self._gradcam_layer = layer

    def process_epoch(self, phase):
        '''Run one pass over `self.loaders[phase]`; update weights only when phase == 'train'.

        Returns:
            (avg_loss, avg_acc): per-sample mean loss and accuracy as Python
            floats; (0.0, 0.0) if the loader yields no samples.
        '''
        is_train = (phase == 'train')
        if is_train:
            self.model.train()
        else:
            self.model.eval()
        correct, running_loss, total = 0, 0.0, 0
        for dat in tqdm(self.loaders[phase]):
            # only inputs and labels are needed on the GPU
            x, y = dat[0].cuda(), dat[1].cuda()

            # x = self.attack(x, y)

            # no autograd graph is built during evaluation
            with torch.set_grad_enabled(is_train):
                logits = self.model(x)
                loss = self.criterion(logits, y)
            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_size = x.shape[0]
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += batch_size
            # criterion returns the batch mean; re-weight so that dividing by
            # the sample count below yields a true per-sample average
            running_loss += loss.item() * batch_size
        if total == 0:
            return 0.0, 0.0
        return running_loss / total, correct / total

    def train(self):
        '''Train for `self.num_epochs` epochs, testing every second epoch and
        saving whenever test accuracy improves.'''
        print('\nBeginning training of model to be saved at {}\n'.format(self.save_path))
        # num_epochs // 3 is 0 for fewer than 3 epochs, which made the modulo below raise
        step_every = max(1, self.num_epochs // 3)
        for epoch in range(self.num_epochs):
            train_loss, train_acc = self.process_epoch('train')
            if (epoch+1) % 2 == 0:
                _, test_acc = self.process_epoch('test')
                print(test_acc)
                if test_acc > self.best_acc:
                    self.best_acc = test_acc
                    self.save_model()

                # NOTE(review): the scheduler only steps on even epochs that are
                # also multiples of num_epochs // 3, far fewer than T_max=epochs
                # steps. Cadence kept as-is; confirm it is intended.
                if (epoch+1) % step_every == 0:
                    self.scheduler.step()
            print('Epoch: {}/{}......Train Loss: {:.3f}......Train Acc: {:.3f}'.format(epoch, self.num_epochs, train_loss, train_acc))
        test_loss, test_acc = self.process_epoch('test')
        print('Test Loss: {:.3f}......Test Acc: {:.3f}'.format(test_loss, test_acc))
        print('\n\nTraining Complete\n\n')

### Analysis

def obtain_features(trial=None, dset='waterbirds', no_relu=False):
    '''Extract backbone features and predictions for the whole training split and cache them.

    Restores the fine-tuned model for `dset`, runs every training batch
    (unshuffled) through `model.feat_net`, and stores features, predictions,
    labels and, for non-utkface datasets, column 0 of the third batch element
    under `results/saved_ftrs_{dset}[_no_relu].pkl`.

    Args:
        trial: forwarded to `Trainer`.
        dset: dataset name understood by `Trainer`.
        no_relu: if True, replace `layer4[2].relu` of the feature encoder with
            an identity before extracting features.
    '''
    trainer = Trainer(trial=trial, mkey='finetune', dset=dset); trainer.restore_model()

    model = trainer.model
    # rebuild loaders without shuffling so cached rows align with dataset order
    trainer.init_loaders(shuffle=False)
    loader = trainer.loaders['train']

    ftr_encoder = model.feat_net

    if no_relu:
        # NOTE(review): if this block reuses its `relu` module for inner
        # activations too, they are disabled as well - confirm intended.
        ftr_encoder.layer4[2].relu = torch.nn.Identity()

    all_ftrs, all_preds, all_ys, all_bgs = [], [], [], []
    # pure inference: avoid building an autograd graph for every batch
    with torch.no_grad():
        for dat in tqdm(loader):
            x = dat[0].cuda()
            ftrs = ftr_encoder(x).flatten(1)
            # ReLU is applied before the head; the stored features are left as extracted
            preds = model.classifier(torch.nn.ReLU()(ftrs)).argmax(1)

            all_ftrs.extend(ftrs.detach().cpu().numpy())
            all_preds.extend(preds.detach().cpu().numpy())
            all_ys.extend(dat[1].numpy())
            if 'utkface' not in dset:
                all_bgs.extend(dat[2][:,0].numpy())

    all_ftrs, all_preds, all_ys, all_bgs = [np.array(x) for x in [all_ftrs, all_preds, all_ys, all_bgs]]

    print(np.average(all_ys == all_preds))
    print(all_ftrs.shape)
    save_dict = dict({'ftrs':all_ftrs, 'preds': all_preds, 'ys': all_ys, 'bgs':all_bgs})
    cache_results(f"results/saved_ftrs_{dset}{'_no_relu' if no_relu else ''}.pkl", save_dict)

def find_important_ftrs(trial=2, dset='waterbirds'):
    '''Rank features per class by (mean activation on correctly classified samples) * (head weight).

    Reads the cache written by `obtain_features` and writes
    `results/important_ftrs_idx_{dset}.pkl`: a dict mapping each class index to
    feature indices sorted from most to least important.

    Classes are taken from the rows of the classifier weight, so datasets with
    more than two classes are covered; for binary datasets the output has the
    same keys (0 and 1) as before.

    Args:
        trial: forwarded to `Trainer`.
        dset: dataset name understood by `Trainer`.
    '''
    results = load_cached_results(f'results/saved_ftrs_{dset}.pkl')
    ftrs, preds, ys = [results[x] for x in ['ftrs', 'preds', 'ys']]

    trainer = Trainer(trial=trial, mkey='finetune', dset=dset); trainer.restore_model()
    W = trainer.model.classifier.weight.detach().cpu().numpy()

    # indices of correctly classified samples
    cc = np.where(ys == preds)[0]

    most_important_ftrs = dict()
    for cls_ind in range(W.shape[0]):
        inds = cc[ys[cc] == cls_ind]
        avg_ftrs = ftrs[inds].mean(0)
        importances = avg_ftrs * W[cls_ind]
        # descending order of importance
        most_important_ftrs[cls_ind] = np.argsort(-1*importances)

    cache_results(f"results/important_ftrs_idx_{dset}.pkl", most_important_ftrs)

# visualizations: top activating imgs, hmaps, and ftr attacks
to_pil = transforms.ToPILImage()
def gen_visualizations(c, f, ftr_rank, ftrs, ys, train_dset, dsetname, model, specific_idx=None):
    '''Save an image grid for feature `f` on class `c`.

    Rows of the grid: most-activating images of class `c`, the output of
    `compute_heatmaps` for them, the output of `feature_attack` for them, and
    the least-activating images of class `c`.

    Args:
        c: class label; only samples with `ys == c` are considered.
        f: feature (column) index into `ftrs`.
        ftr_rank: rank of the feature; only used in the output file name.
        ftrs: 2-D array of features, rows aligned with `ys` and `train_dset`.
        ys: array of labels.
        train_dset: indexable dataset whose items have the image at position 0.
        dsetname: sub-directory of `plots2/` to write into.
        model: passed to `compute_nams` and `feature_attack`.
        specific_idx: optional pair `[top_ranks, bottom_offsets]`. `top_ranks`
            are positions in the descending activation ranking;
            `bottom_offsets` are positions (0-9) within the last ten entries of
            that ranking. If None, the top 10 and bottom 10 are used.
    '''
    # Get top activating images
    in_cls_idx = np.where(ys==c)[0]
    ftrs = ftrs[in_cls_idx]
    idx = np.argsort(-1*ftrs[:, f])

    if specific_idx is None:
        top_idx = idx[:10]
        bot_idx = idx[-10:]
        ext = ''
        nrow = 10
    else:
        # build the wanted rank sets once instead of once per ranked element
        top_ranks = set(specific_idx[0])
        bot_ranks = {len(idx)-(10-x) for x in specific_idx[1]}
        top_idx = [ind for j, ind in enumerate(idx) if j in top_ranks]
        bot_idx = [ind for j, ind in enumerate(idx) if j in bot_ranks]
        # ext = 'for_paper_'
        ext = 'for_supp_'
        nrow = len(specific_idx[0])
    # switch for the feature-attack row (currently always on)
    include_ftr_atks = True

    top_imgs = torch.stack([train_dset[in_cls_idx[i]][0] for i in top_idx])
    bot_imgs = torch.stack([train_dset[in_cls_idx[i]][0] for i in bot_idx])
    nams = compute_nams(model, top_imgs, f)
    hmaps = compute_heatmaps(top_imgs, nams)

    rows = [top_imgs, hmaps]
    if include_ftr_atks:
        ftr_atks, _ = feature_attack(model, top_imgs, [f])
        rows.append(ftr_atks)
    rows.append(bot_imgs)
    all_viz = torch.vstack(rows)

    out_path = f'plots2/{dsetname}/{ext}cls_{c}_ftr_rank_{ftr_rank}_ftr_{f}.jpg'
    # make_grid's keyword is `padding`; `pad` is not one of its parameters
    to_pil(make_grid(all_viz, nrow=nrow, padding=20, pad_value=1)).save(out_path)
    print('Saving to '+out_path)

def visualize_important_ftrs(num_ftrs_to_visualize=15, dset='waterbirds'):
    '''Call `gen_visualizations` for the top-ranked features of every class.

    Uses the caches `results/saved_ftrs_{dset}.pkl` and
    `results/important_ftrs_idx_{dset}.pkl`.

    Args:
        num_ftrs_to_visualize: how many of the highest-ranked features per class to render.
        dset: dataset name understood by `Trainer`.
    '''
    cached = load_cached_results(f'results/saved_ftrs_{dset}.pkl')
    ftrs, _preds, ys, _bgs = (cached[key] for key in ('ftrs', 'preds', 'ys', 'bgs'))

    ranked_ftrs_by_cls = load_cached_results(f'results/important_ftrs_idx_{dset}.pkl')

    trainer = Trainer(mkey='finetune', dset=dset)
    trainer.restore_model()
    model = trainer.model
    # unshuffled loaders; only the underlying training dataset is used
    trainer.init_loaders(shuffle=False)
    train_dset = trainer.loaders['train'].dataset

    for c in ranked_ftrs_by_cls:
        top_ftrs = ranked_ftrs_by_cls[c][:num_ftrs_to_visualize]
        for ftr_rank, f in enumerate(top_ftrs):
            gen_visualizations(c, f, ftr_rank, ftrs, ys, train_dset, dset, model)

def cool_examples():
    '''
    Save hand-picked examples of discovered spurious features (paper / supplement).

    Each entry of `egs_by_dset` is a list of parallel lists:
        [classes, feature ranks, feature indices, specific_idx pairs, ...]
    The utkface entries carry two extra lists (presumably class names and
    feature descriptions for captions - confirm); they are not used here.
    '''

    # for paper
    # egs_by_dset = dict({
    #     'waterbirds': [[0,1], [2,0], [513, 1697], [[[0,1,8], [3,5,9]], [[2,3,4], [0,8,9]]]],
    #     'celebA': [[0,1], [6,14], [753, 1139], [[[2,3,4], [0,5,9]],[[2,3,4], [3,4,8]]]]
    # })

    # for supplemental
    egs_by_dset = dict({
        'utkface_race': [[3], [0], [824], [[[0,4,3], [0,2,9]]], ['Indian'], ['Glasses']],
        # raw string keeps the literal backslash without an invalid escape sequence
        'utkface_gender': [[0,1], [3,3], [250,1799], [[[0,1,2],[0,1,5]], [[0,2,8],[0,1,2]]], ['Male', 'Female'], [r'Suit \& Tie', 'Pink (color)']],
        'waterbirds': [[0,1], [3, 1], [1707, 613], [[[0,1,2], [0,1,2]], [[0,1,2], [0,1,2]]]],
        'celebA': [[0,0,1,1], [0, 1, 0, 2], [1709, 1815, 188, 4], [[[0,1,2], [1,3,6]], [[0,1,2], [1,4,6]], [[0,1,4], [0,1,5]], [[0,1,2], [3,6,8]]]]
    })

    for dset in egs_by_dset:
        # NOTE(review): features come from the standard cache, as before (they
        # used to be bound to a variable named `no_relu_ftrs`). To rank with
        # no-ReLU features instead, load f'results/saved_ftrs_{dset}_no_relu.pkl'.
        results = load_cached_results(f'results/saved_ftrs_{dset}.pkl')
        ftrs, ys = results['ftrs'], results['ys']

        trainer = Trainer(mkey='finetune', dset=dset); trainer.restore_model()
        model = trainer.model
        trainer.init_loaders(shuffle=False)
        train_dset = trainer.loaders['train'].dataset

        # only the first four lists are parallel arguments for gen_visualizations;
        # zipping the optional caption lists too would break the 4-way unpacking
        for c, ftr_rank, f, specific_idx in zip(*egs_by_dset[dset][:4]):
            gen_visualizations(c, f, ftr_rank, ftrs, ys, train_dset, dset, model, specific_idx)

######## COMPUTING SPURIOUS RANKINGS
def obtain_all_imgnet_robust_ftrs(no_relu=True, mkey='robust_resnet50_l2_eps3'):
    '''Extract penultimate features of a pretrained model on ImageNet train and val and cache them.

    The result is written to `results/{mkey}_ftrs[_no_relu].pkl` as
    `{split: {'ftrs': ..., 'ys': ...}}`; the file is rewritten after each split.
    After each split, accuracy of `model.fc` on 1000 random cached samples is
    printed as a sanity check.

    Args:
        no_relu: if True, replace `model.layer4[2].relu` with an identity
            before extracting features.
        mkey: model key; must contain one of 'dino', 'clip', 'robust', 'moco', 'simclr'.

    Raises:
        ValueError: if `mkey` matches none of the supported model families.
    '''
    apply_norm = True
    if 'dino' in mkey:
        model = DinoWrapper(mkey)
    elif 'clip' in mkey:
        model = ClipZeroShot(mkey[5:].replace('_', '/'))
        apply_norm = False
    elif 'robust' in mkey:
        model = load_robust_resnet(mkey)
    elif 'moco' in mkey:
        if 'vit' in mkey:
            arch = 'vit_small' if 'vit-s' in mkey else 'vit_base'
            model = vits.__dict__[arch]()
            key = mkey[5:] + '-'
        else:
            # `models` is what this file imports from torchvision
            model = models.resnet50()
            key = ''
        ckpt = torch.load('{}/linear-{}300ep.pth.tar'.format(_MOCO_ROOT, key))
        og_states = ckpt['state_dict']
        # drop the leading 'module.' from every checkpoint key
        state_dict = dict({k[len('module.'):]:og_states[k] for k in og_states})
        model.load_state_dict(state_dict)
    elif 'simclr' in mkey:
        model = resnet_wider.__dict__[mkey[7:]]()
        # the checkpoint file is selected by the last character of mkey
        ckpt = torch.load('{}/resnet50-{}x.pth'.format(_SIMCLR_ROOT, mkey[-1]))
        model.load_state_dict(ckpt['state_dict'])
        apply_norm = False
    else:
        raise ValueError(f'Unsupported model key: {mkey!r}')

    model = model.cuda().eval()
    if no_relu:
        model.layer4[2].relu = torch.nn.Identity()
    save_dict = dict()
    save_path = 'results/{}_ftrs{}.pkl'.format(mkey, '_no_relu' if no_relu else '')

    # every child module except the last one
    ftr_encoder = torch.nn.Sequential(*list(model.children())[:-1])

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    t = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
    for split in ['train', 'val']:
        dset = datasets.ImageNet(root=_IMAGENET_ROOT, split=split, transform=t)
        loader = torch.utils.data.DataLoader(dset, batch_size=64, shuffle=False, num_workers=16)
        ftr_batches, y_batches = [], []
        # pure inference: no autograd graph needed
        with torch.no_grad():
            for x, y in tqdm(loader):
                if apply_norm:
                    x = normalize(x)
                ftr_batches.append(ftr_encoder(x.cuda()).flatten(1).cpu().numpy())
                y_batches.append(y.numpy())
        ftrs, ys = np.concatenate(ftr_batches), np.concatenate(y_batches)
        print(ftrs.shape, ys.shape, ys[:10])
        save_dict[split] = dict({'ftrs': ftrs, 'ys': ys})
        print(f'Done for {split} split. Saving to {save_path}')
        cache_results(save_path, save_dict)

        # sanity check on a random subset of 1000 cached samples
        rand_idx = np.arange(ftrs.shape[0])
        np.random.shuffle(rand_idx)
        ftrs, ys = [x[rand_idx[:1000]] for x in [ftrs, ys]]
        with torch.no_grad():
            if no_relu:
                print((model.fc(torch.tensor(ftrs).cuda()).argmax(1) == torch.LongTensor(ys).cuda()).sum() / ys.shape[0])
            else:
                print((model.fc(torch.nn.ReLU()(torch.tensor(ftrs)).cuda()).argmax(1) == torch.LongTensor(ys).cuda()).sum() / ys.shape[0])

def imagenet_minority_examples(num_classes=25, no_relu=True, split='train'):
    '''Rank the images of each class by their average z-score over the class's core / spurious features.

    Reads `ftr_types/by_class_dict.pkl` (per class: 'core' and 'spurious'
    feature collections) and the cached features for `split`, then writes
    `results/new_img_rankings_by_idx[_no_relu]_{split}.pkl`:
    `{class: {ftr_type: dataset indices sorted by descending average z-score}}`.
    A feature type with no features is omitted for that class, as is every
    ranking of a class that has no samples in the cache.

    Args:
        num_classes: kept for backward compatibility; it previously only sized
            a figure that was never drawn or saved.
        no_relu: selects the `_no_relu` variant of the feature cache and output file.
        split: key of the feature cache to rank ('train' or 'val').
    '''
    by_class_dict = load_cached_results('ftr_types/by_class_dict.pkl')

    # classes with the most spurious features first
    sorted_by_cls = dict(sorted(by_class_dict.items(), key=lambda item: -1*len(item[1]['spurious'])))

    # load the feature cache once (it used to be read once per key)
    suffix = '_no_relu' if no_relu else ''
    cached = load_cached_results('results/robust_resnet50_l2_eps3_ftrs{}.pkl'.format(suffix))[split]
    ftrs, ys = cached['ftrs'], cached['ys']

    idx_dict = dict()

    for c in tqdm(sorted_by_cls):
        cls_idx = np.where(ys == c)[0]
        cls_ftrs = ftrs[cls_idx]
        idx_dict[c] = dict()
        if cls_idx.size == 0:
            # nothing to rank for this class
            continue

        for ftr_type in ['core', 'spurious']:
            ftr_set = sorted_by_cls[c][ftr_type]
            if len(ftr_set) == 0:
                continue
            z_scores = []
            for f in ftr_set:
                col = cls_ftrs[:, f]
                mean, std = col.mean(), col.std()
                if std > 0:
                    z_scores.append((col - mean) / std)
                else:
                    # constant feature: contribute 0 instead of NaN (0/0),
                    # which would make the whole average NaN
                    z_scores.append(np.zeros_like(col))
            avg_z_scores = np.average(z_scores, 0)
            img_ranks = np.argsort(-1*avg_z_scores)
            ranked_cls_idx = cls_idx[img_ranks]
            idx_dict[c][ftr_type] = ranked_cls_idx

    cache_results(f"results/new_img_rankings_by_idx{'_no_relu' if no_relu else ''}_{split}.pkl", idx_dict)

def view_spur_gap_samples(no_relu=True, num=10, random=True, ext=''):
    '''Plot, per class, the `num` highest- and `num` lowest-ranked images of the 'spurious' ranking.

    One figure is created and saved once per split ('val', then 'train') to
    `plots/new_spurious_gap_egs{ext}_{split}.jpg`.

    Args:
        no_relu: selects the `_no_relu` variant of the rankings file.
        num: number of images per row (top row: highest ranked, bottom row: lowest ranked).
        random: if True, show 5 classes picked at random positions out of 357;
            otherwise use a 36 x 10 grid of axes for all classes.
        ext: extra suffix for the output file name.
    '''
    if random:
        ncols = 1; nrows = 5
        f, axs = plt.subplots(nrows, ncols, figsize=(3,4))
        if ncols == 1:
            # a single column comes back 1-D; make it indexable as [row, col]
            axs = axs.reshape(nrows, 1)
        keep_idx = np.random.choice(357, nrows*ncols, replace=False)
    else:
        ncols = 10
        f, axs = plt.subplots(36, ncols, figsize=(35,30))
    for axi in axs.ravel():
        axi.set_axis_off()

    suffix = '_no_relu' if no_relu else ''
    for split in ['val', 'train']:
        rankings = load_cached_results(f"results/new_img_rankings_by_idx{suffix}_{split}.pkl")
        t = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()])
        dset = datasets.ImageNet(root=_IMAGENET_ROOT, split=split, transform=t)

        # classes that have a spurious ranking, optionally sub-sampled by position
        spurious_cs = [c for c in rankings if 'spurious' in rankings[c]]
        if random:
            spurious_cs = [c for pos, c in enumerate(spurious_cs) if pos in keep_idx]

        for j, c in enumerate(tqdm(spurious_cs)):
            ranked = rankings[c]['spurious']
            bot = [dset[ranked[-k]][0] for k in range(1, num+1)]
            top = [dset[ranked[k]][0] for k in range(num)]

            # (C, H, W) -> (H, W, C) for imshow
            grid = make_grid(top+bot, nrow=num).permute(1, 2, 0)
            row, col = divmod(j, ncols)
            ax = axs[row, col]
            ax.imshow(grid)
            ax.set_title(imagenet_classes[c].title(), fontsize=4)

        f.savefig(f'plots/new_spurious_gap_egs{ext}_{split}.jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)

def view_extreme_examples_for_class(cls_idx, num=5, no_relu=True, split='train'):
    '''Plot, for each requested class, its `num` lowest- and `num` highest-ranked 'spurious' images.

    Each class gets one row: lowest-ranked images on the left, a separator
    tile with three dark squares, then highest-ranked images with the very
    highest at the far right. The figure is saved to
    `plots/more_clarifying_egs.jpg`. Classes without a 'spurious' ranking are
    reported and skipped; if none remain, nothing is plotted.

    Args:
        cls_idx: iterable of class indices.
        num: number of images on each side of the separator.
        no_relu: selects the `_no_relu` variant of the rankings file.
        split: ImageNet split of the rankings file and the images.
    '''
    rankings = load_cached_results(f"results/new_img_rankings_by_idx{'_no_relu' if no_relu else ''}_{split}.pkl")

    keep_idx = []
    for c in cls_idx:
        # .get: a class absent from the rankings is treated like one without spurious rankings
        if 'spurious' not in rankings.get(c, {}):
            print(f'No spuriosity rankings for class {imagenet_classes[c]}')
        else:
            keep_idx.append(c)
    cls_idx = keep_idx; print('Keeping ', cls_idx)
    if not cls_idx:
        # plt.subplots cannot create a grid with zero rows
        return

    t = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()])
    dset = datasets.ImageNet(root=_IMAGENET_ROOT, split=split, transform=t)

    # squeeze=False: always a 2-D array of axes, even for a single class
    f, axs = plt.subplots(len(cls_idx), 1, figsize=(2*num, len(cls_idx)*1.25), squeeze=False)

    # white separator tile with three dark squares along its middle row
    middle = torch.ones((3,224,224))
    middle[:,106:118, 50:62] = 0
    middle[:,106:118, 106:118] = 0
    middle[:,106:118, 162:174] = 0
    for ax, c in zip(axs[:, 0], cls_idx):

        # using new rankings
        imgs = [dset[rankings[c]['spurious'][-1*i]][0] for i in range(1,num+1)] + \
               [middle] + [dset[rankings[c]['spurious'][i]][0] for i in range(num-1, -1, -1)]

        grid = make_grid(imgs, nrow=2*num+1, padding=0)

        ax.imshow(grid.swapaxes(0,1).swapaxes(1,2))
        if c == cls_idx[0]:
            # only the first row carries the low/high spuriosity legend
            ax.set_title('$\\leftarrow$Low Spuriosity'+ f'Class: {imagenet_classes[c].upper()}'.center(100) + 'High Spuriosity$\\rightarrow$', fontsize=10)
        else:
            ax.set_title(f'Class: {imagenet_classes[c].upper()}', fontsize=10)
        ax.set_axis_off()
    f.subplots_adjust(hspace=0.02)
    f.tight_layout(); f.savefig('plots/more_clarifying_egs.jpg', dpi=300)#, bbox_inches='tight', pad_inches=0.02)

if __name__ == '__main__':
    # Script entry point. Only cool_examples() is active; the commented-out
    # calls below are the other steps this file can run (training, feature
    # extraction, feature ranking, ImageNet rankings and figures).
    # trainer = Trainer(trial=2, mkey='finetune', epochs=20)
    # trainer = Trainer(mkey='finetune', dset='celebA', epochs=20)
    # trainer.train()
    # for dset in ['waterbirds', 'celebA']:
        # obtain_features(dset=dset, no_relu=True)
    #     find_important_ftrs(dset=dset)
        # visualize_important_ftrs(dset=dset)

    # view_and_assess_subpopulation_discovery(dset='celebA')
    cool_examples()

    # obtain_all_imgnet_robust_ftrs(mkey='simclr_resnet50x1')
    # obtain_all_imgnet_robust_ftrs(no_relu=False)
    # obtain_all_imgnet_robust_ftrs(no_relu=False)
    # imagenet_minority_examples(no_relu=True, split='val')
    # imagenet_minority_examples(no_relu=True)
    # imagenet_minority_examples(no_relu=False)

    # view_extreme_examples_for_class([479, 193, 626, 422], split='train', num=4)
    # view_extreme_examples_for_class([479, 193, 391, 815, 526, 982, 722, 973, 856, 925, 626, 422, 693, 893, 545, 20, 832, 306, 317, 319, 320, 321, 322, 323, 324, 325, 522, 101, 386, 355, 360, 202], split='train')
    # view_extreme_examples_for_class([479, 673, 815, 962, 526, 834, 120, 774, 453, 422], split='train')
    # view_extreme_examples_for_class([479, 673, 453, 422], split='train', num=3)
    # view_extreme_examples_for_class([795, 321, 355], split='train')
    # view_extreme_examples_for_class([360, 306, 626], split='train', num=4)
    # view_extreme_examples_for_class([815,962,526], split='train', num=5)
    # view_extreme_examples_for_class([834, 120, 774], split='train', num=5)
    # view_spur_gap_samples() 
    # save_a_bunch_of_random_egs()

    # for arl
    # view_extreme_examples_for_class([895, 744, 403, 405, 407, 465, 471, 488, 489, 507, 510, 652, 734, 821, 833], num=7)
