import torch
from torchvision import models, datasets, transforms
from my_utils import *
from tqdm import tqdm
import time
import datetime
import timm
import resnet_wider
import vits
import os
from ood_eval import *
from paths import *

def prep_loader(mode='bottom', k=100, split='train', normalize=True):
    '''
    Returns a shuffled DataLoader (batch_size=32) over ImageNet `split` images filtered by spuriosity rank.

    Rankings are read from results/new_img_rankings_by_idx_no_relu_{split}.pkl. For every class that
    has a 'spurious' ranking, `mode` selects which ranked image indices are kept. `mode` is matched by
    substring, so several selectors can be combined in one string (e.g. 'bottom_top'):
        'bottom'       - the last k entries of the ranking
        'top'          - the first k entries of the ranking
        'mid'/'middle' - k entries centred on the middle of the ranking
        'random'       - k entries chosen at random (baseline)
        'full'         - every ranked entry
    Going bottom -> mid -> top runs from lowest to highest spuriosity.
    Classes without a 'spurious' ranking contribute k random entries of their 'core' ranking,
    unless `mode` contains 'for_gap', in which case they are left out.

    k determines the number of images per class; normalize toggles ImageNet mean/std normalization.
    '''
    rankings = load_cached_results(f"results/new_img_rankings_by_idx_no_relu_{split}.pkl")
    transform_list = [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    if normalize:
        transform_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    t = transforms.Compose(transform_list)
    dset = datasets.ImageNet(root=_IMAGENET_ROOT, split=split, transform=t)

    # A set gives O(1) membership tests in the per-class loop below.
    spur_cls_idx = {c for c in rankings if 'spurious' in rankings[c]}
    img_idx = []
    for c in range(1000):
        if c in spur_cls_idx:
            ranked = rankings[c]['spurious']
            n = len(ranked)
            if 'bottom' in mode:
                # Explicit start index: ranked[-k:] would return the whole ranking when k == 0.
                img_idx.extend(ranked[max(n - k, 0):])
            if 'top' in mode:
                img_idx.extend(ranked[:k])
            if 'mid' in mode:
                # Clamp the start at 0 so it never goes negative (a negative start would wrap
                # around to the end of the ranking) and take exactly k entries, also for odd k.
                start = max(n // 2 - k // 2, 0)
                img_idx.extend(ranked[start:start + k])
            if 'random' in mode:
                idx = np.arange(n)
                np.random.shuffle(idx)
                img_idx.extend(np.asarray(ranked)[idx[:k]])
            if 'full' in mode:
                img_idx.extend(ranked)
        elif 'for_gap' not in mode:
            core = rankings[c]['core']
            idx = np.arange(len(core))
            np.random.shuffle(idx)
            img_idx.extend(np.asarray(core)[idx[:k]])

    dset.samples = [dset.samples[i] for i in img_idx]
    loader = torch.utils.data.DataLoader(dset, shuffle=True, batch_size=32, num_workers=16)
    return loader

# Number of cached feature rows processed per step when training / scoring the linear head.
ftr_batch_size: int = 256
def cache_features(model, loader):
    '''
    Run `model.forward_features` over every batch of `loader` and collect the outputs on the CPU.

    Returns (ftrs, ys) as numpy arrays: ftrs holds one flattened feature row per image, in loader
    order, and ys holds the matching labels. If the loader yields no batches, both are empty arrays.

    Side effect: the model is put in eval mode and moved to CUDA for the forward passes, then moved
    back to the CPU before returning. nn.Module.cuda()/.cpu() act in place, so the caller's model
    ends up on the CPU.
    '''
    model = model.eval().cuda()
    ftr_batches, y_batches = [], []
    with torch.no_grad():
        for x, y in tqdm(loader):
            ftr_batches.append(model.forward_features(x.cuda()).flatten(1).cpu().numpy())
            y_batches.append(np.asarray(y))
    model = model.cpu()
    if not ftr_batches:
        return np.array([]), np.array([])
    # Concatenate whole batches once instead of building a per-row Python list (less memory, faster).
    return np.concatenate(ftr_batches), np.concatenate(y_batches)

def compute_spurious_gap(head, val_cached, high_spur_cached, low_spur_cached, device='cuda', batch_size=None):
    '''
    Score a classification head on cached features.

    Each *_cached argument is an (ftrs, ys) pair as returned by cache_features.
    Returns (val_acc, spur_gap):
        val_acc  - accuracy (%) on val_cached
        spur_gap - accuracy on high_spur_cached minus accuracy on low_spur_cached (percentage points)
    An empty split scores NaN instead of raising ZeroDivisionError.

    device:     where batches are evaluated; `head` must already live there (default 'cuda').
    batch_size: feature rows per forward pass; defaults to the module-level ftr_batch_size.
    '''
    if batch_size is None:
        batch_size = ftr_batch_size
    accs = []
    with torch.no_grad():
        for ftrs, ys in (val_cached, high_spur_cached, low_spur_cached):
            n = ftrs.shape[0]
            if n == 0:
                accs.append(float('nan'))
                continue
            correct = 0
            for i in range(0, n, batch_size):
                batch_ftrs = torch.as_tensor(ftrs[i:i + batch_size], device=device)
                batch_ys = torch.as_tensor(ys[i:i + batch_size], dtype=torch.long, device=device)
                correct += (head(batch_ftrs).argmax(1) == batch_ys).sum().item()
            accs.append(100 * correct / n)
    return accs[0], accs[1] - accs[2]

def tune_head(mkey, mode='bottom', k=100, epochs=200):
    '''
    Finetune only the linear classification head of a pretrained model on spuriosity-filtered images.

    The model comes from load_model(mkey) and must expose .forward_features (images -> features) and
    .head (the linear classifier). Training images are selected with prep_loader(mode=mode, k=k).
    After every epoch, compute_spurious_gap scores the head on three val subsets:
        - full val
        - the top 10 ranked images of each spurious class
        - the bottom 10 ranked images of each spurious class

    The head state with the smallest absolute spurious gap is kept. Training stops early once the
    absolute gap drops below 5 percentage points. The following are saved to
    ft_heads3/{mkey}/{mode}_{k}.pth:
        - the kept state, with its gap and val accuracy
        - the per-epoch gap/accuracy histories
        - the elapsed time
    '''
    start = time.time()

    model, normalize = load_model(mkey)

    print('Preparing data loaders.')
    # Loader for training
    loader = prep_loader(mode=mode, k=k, split='train', normalize=normalize)
    # Loaders for evaluation
    full_val_loader = prep_loader(k=50, split='val', mode='full', normalize=normalize)
    high_spur_val_loader = prep_loader(mode='top__for_gap', k=10, split='val', normalize=normalize)
    low_spur_val_loader = prep_loader(mode='bottom__for_gap', k=10, split='val', normalize=normalize)

    print('Data Loaders prepared. Now Caching Features.')
    # Only the head is trained, so backbone features are computed once and reused every epoch.
    cached_ftrs, ys = cache_features(model, loader)
    val_cached = cache_features(model, full_val_loader)
    high_spur_cached = cache_features(model, high_spur_val_loader)
    low_spur_cached = cache_features(model, low_spur_val_loader)
    print('Caching Complete. Commencing Finetuning of Linear Classification Head.')

    head = model.head.cuda()
    lr = 0.01 if 'resnet' in mkey else 0.1
    optimizer = torch.optim.SGD(list(head.parameters()), momentum=0.9, lr=lr)
    criterion = torch.nn.CrossEntropyLoss().cuda()
    best_spur_gap = 100
    best_head_state = None
    best_val_acc = 0
    val_acc, spur_gap = compute_spurious_gap(head, val_cached, high_spur_cached, low_spur_cached)
    spur_gaps, val_accs = [spur_gap], [val_acc]
    print(f'Before Finetuning...Val Acc: {val_acc:.2f}%, Spurious Gap: {spur_gap:.2f}%')
    for epoch in range(epochs):
        head = head.train().cuda()
        cc, ctr, running_loss = 0, 0, 0
        for i in range(0, cached_ftrs.shape[0], ftr_batch_size):
            batch_ftrs = torch.tensor(cached_ftrs[i:i+ftr_batch_size]).cuda()
            batch_ys = torch.LongTensor(ys[i:i+ftr_batch_size]).cuda()
            logits = head(batch_ftrs)
            loss = criterion(logits, batch_ys)
            n = batch_ys.shape[0]
            ctr += n
            cc += (logits.argmax(1) == batch_ys).sum().item()
            # The criterion returns a batch mean; weight it by batch size so running_loss / ctr
            # is a true per-sample mean.
            running_loss += loss.item() * n
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        val_acc, spur_gap = compute_spurious_gap(head, val_cached, high_spur_cached, low_spur_cached)
        denom = max(ctr, 1)  # guard against an empty training subset
        print(f'Epoch: {epoch}/{epochs}, Train Loss: {running_loss/denom:.3f}, Train Acc: {100*cc/denom:.2f}%, Val Acc: {val_acc:.2f}%, Spurious Gap: {spur_gap:.2f}%')

        if np.abs(spur_gap) < best_spur_gap:
            # Independent CPU copy so the snapshot can never alias parameters that later epochs update.
            best_head_state = {name: t.detach().cpu().clone() for name, t in head.state_dict().items()}
            best_spur_gap = np.abs(spur_gap)
            best_val_acc = val_acc

        spur_gaps.append(spur_gap)
        val_accs.append(val_acc)

        # Stop on the same criterion used for model selection: the absolute gap.
        if np.abs(spur_gap) < 5:
            break

    elapsed = str(datetime.timedelta(seconds=int(time.time() - start)))
    save_dict = dict({'state': best_head_state, 'spur_gap': best_spur_gap, 'val_acc': best_val_acc, 'spur_gaps': spur_gaps, 'val_accs': val_accs, 'elapsed': elapsed})
    save_path = f'ft_heads3/{mkey}/{mode}_{k}.pth'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(save_dict, save_path)
    print(f'Saved head with spurious gap of {best_spur_gap:.2f}% and validation accuracy of {best_val_acc:.2f}% to {save_path}. Time elapsed: {elapsed}.')

def load_model(mkey, head_args=None):
    '''
    Load a pretrained model by key and, optionally, a finetuned linear head for it.

    Supported keys:
        - keys containing 'resnet50':
            'simclr_resnet50' loads the SimCLR checkpoint,
            keys containing 'robust' go through load_robust_resnet,
            anything else uses torchvision's pretrained resnet50
        - 'deit_small'
        - 'moco_vit-s'
    Every returned model exposes .forward_features (images -> features) and .head (the linear classifier).

    head_args: optional (mode, k) pair. The head state stored under 'state' in
    ft_heads3/{mkey}/{mode}_{k}.pth (as written by tune_head) is loaded into model.head.

    Returns (model, normalize): the model in eval mode on CUDA, and whether inputs should get
    ImageNet mean/std normalization (False only for 'simclr_resnet50').

    Raises ValueError for an unknown mkey, or when the head file stores no state.
    '''
    normalize = True
    print(f'loading model {mkey} with args {head_args}')
    if 'resnet50' in mkey:
        if mkey == 'simclr_resnet50':
            model = resnet_wider.__dict__['resnet50x1']()
            ckpt = torch.load(_SIMCLR_ROOT)
            model.load_state_dict(ckpt['state_dict'])
            normalize = False
        elif 'robust' in mkey:
            model = load_robust_resnet(mkey)
        else:
            model = models.resnet50(pretrained=True).eval().cuda()
        # Give the ResNet the common interface: every child except the final fc is the feature
        # extractor, and fc is the head.
        model.forward_features = torch.nn.Sequential(*list(model.children())[:-1])
        model.head = model.fc
    # the following two models already have .forward_features and .head implemented
    elif mkey == 'deit_small':
        model = timm.create_model('deit_small_patch16_224', pretrained=True).eval().cuda()
    elif mkey == 'moco_vit-s':
        model = vits.__dict__['vit_small']()
        ckpt = torch.load(_MOCO_ROOT)
        og_states = ckpt['state_dict']
        # Strip the 'module.' prefix (presumably from a DataParallel/DDP wrapper), leaving any
        # key without the prefix untouched instead of chopping off its first characters.
        prefix = 'module.'
        state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value
                      for key, value in og_states.items()}
        model.load_state_dict(state_dict)
    else:
        raise ValueError(f'Unknown model key: {mkey}')

    if head_args is not None:
        mode, k = head_args
        head_path = f'ft_heads3/{mkey}/{mode}_{k}.pth'
        # The head is moved to the CPU before loading, so load the tensors onto the CPU too.
        head_state_dict = torch.load(head_path, map_location='cpu')['state']
        if head_state_dict is None:
            raise ValueError(f'No finetuned head state stored in {head_path}')
        model.head = model.head.cpu()
        model.head.load_state_dict(head_state_dict)
        model.head = model.head.cuda()

    return model.eval().cuda(), normalize

def compute_accuracy_by_rank(mkey, mode, k):
    '''
    Evaluate a model with its finetuned (mode, k) head on val images grouped by spuriosity rank.

    For each rank r in 0..49, a loader holds the r-th ranked val image of every class that has a
    'spurious' ranking. The number of correct predictions per rank is stored back into
    ft_heads3/{mkey}/{mode}_{k}.pth under two keys:
        'accs_by_rank'     - correct-prediction counts per rank
        'num_spur_classes' - number of classes with a 'spurious' ranking
    Accuracy per rank is then printed as an integer percentage of num_spur_classes. It is printed
    both raw and smoothed with a centred moving average over ranks r-5..r+5.
    '''
    model, normalize = load_model(mkey, head_args=(mode, k))

    save_path = f'ft_heads3/{mkey}/{mode}_{k}.pth'
    save_dict = torch.load(save_path)

    rankings = load_cached_results("results/new_img_rankings_by_idx_no_relu_val.pkl")
    transform_list = [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    if normalize:
        transform_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    t = transforms.Compose(transform_list)
    dset = datasets.ImageNet(root=_IMAGENET_ROOT, split='val', transform=t)
    og_samples = dset.samples.copy()

    spur_cls_idx = [c for c in rankings if 'spurious' in rankings[c]]
    n_ranks = 50  # assumes every spurious class has at least 50 ranked val images — TODO confirm
    accs = []
    with torch.no_grad():
        for rank in tqdm(range(n_ranks)):
            # Each rank gets its own loader holding one image per spurious class.
            img_idx = np.array([rankings[c]['spurious'][rank] for c in spur_cls_idx])
            dset.samples = [og_samples[i] for i in img_idx]
            loader = torch.utils.data.DataLoader(dset, shuffle=False, batch_size=32, num_workers=16)

            cc = 0
            for x, y in loader:
                ftrs = model.forward_features(x.cuda()).flatten(1)
                cc += (model.head(ftrs).argmax(1) == y.cuda()).sum().item()
            # Raw count of correct predictions; normalised by num_spur_classes when printed.
            accs.append(cc)

    save_dict['accs_by_rank'] = np.array(accs)
    save_dict['num_spur_classes'] = len(spur_cls_idx)
    torch.save(save_dict, save_path)

    accs = save_dict['accs_by_rank']
    num_spur = save_dict['num_spur_classes']
    # Centred window r-5..r+5. Slice ends are exclusive and numpy clips past the end, so use r + 6.
    # Clamping the end to 49 would drop rank 49 from every window.
    smoothed_accs = np.array([np.nanmean(accs[max(i - 5, 0):i + 6]) for i in range(len(accs))])

    print(f'{mkey}_{mode}_{k}')
    print('Accuracy by Rank: ', [int(a*100 / num_spur) for a in accs])
    print('Smoothed Accuracy by Rank: ', [int(a*100 / num_spur) for a in smoothed_accs])
    

### OOD eval
def eval_ood():
    '''
    Evaluate the finetuned 'bottom_100' and 'random_100' heads of several models on OOD benchmarks.

    The benchmarks are run via eval_on_objectnet, eval_on_imagenet_r and eval_on_sketch.
    Results live in results/ft_heads_ood.pkl as results[mkey][head_args][dset_name] = acc.
    Entries already present in that cache are not recomputed, and the cache is rewritten after
    each new evaluation. Every accuracy is printed multiplied by 100, so acc is presumably a
    fraction in [0, 1] — confirm against ood_eval.
    '''
    results_path = 'results/ft_heads_ood.pkl'
    results = load_cached_results(results_path)

    model_keys = ['resnet50', 'robust_resnet50_linf_eps4.0', 'deit_small', 'moco_vit-s', 'simclr_resnet50']
    head_configs = ['bottom_100', 'random_100']
    benchmarks = [('objectnet', eval_on_objectnet), ('r', eval_on_imagenet_r), ('sketch', eval_on_sketch)]

    for mkey in model_keys:
        per_model = results.setdefault(mkey, dict())
        for head_args in head_configs:
            per_head = per_model.setdefault(head_args, dict())
            for ood_dset, eval_fn in benchmarks:
                if ood_dset not in per_head:
                    # head_args encodes (mode, k), e.g. 'bottom_100' -> ('bottom', 100)
                    mode, k = head_args.split('_')
                    model, normalize = load_model(mkey, (mode, int(k)))
                    acc, _ = eval_fn(model, normalize)
                    per_head[ood_dset] = acc
                    cache_results(results_path, results)
                print(f'Model: {mkey}, Head: {head_args}, Dset: {ood_dset}, Accuracy: {100*per_head[ood_dset]:.2f}%')


if __name__ == '__main__':
    # Report accuracy by spuriosity rank for the 'bottom' and 'random' heads (k=100) of each model.
    # The heads must already exist under ft_heads3/; create them with tune_head(mkey, mode, 100).
    # Run eval_ood() instead to evaluate those heads on the OOD benchmarks.
    evaluated_models = ['resnet50', 'robust_resnet50_linf_eps4.0', 'deit_small', 'moco_vit-s', 'simclr_resnet50']
    for mkey in evaluated_models:
        for head_mode in ('bottom', 'random'):
            compute_accuracy_by_rank(mkey, head_mode, 100)
    

