import glob
import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm
from robustness import model_utils as robust_model_utils

from tools.datasets import ImageNet
from tools.model_utils import make_and_restore_model, eval_model
from model_eval import load_cached_results, cache_results
from model_eval import get_arch

# Side length (pixels) used by ObjectNet's default Resize + CenterCrop transform.
input_size = 224

class ObjectNet(Dataset):
    """ObjectNet images paired with their ObjectNet class ids.

    Every ``*.<img_format>`` file found recursively under ``root`` is one
    sample. Its label is looked up from the name of its parent folder via
    ``<root>/mappings/folder_to_onet_id.json``.

    Args:
        root: ObjectNet dataset directory (a trailing slash is optional).
        transform: applied to each PIL image after the border crop.
        normalize: optional transform applied after ``transform``; skipped if None.
        img_format: file extension of the images to collect (without the dot).
    """

    # Alternative default that was tried:
    # transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor()])
    def __init__(self, root='/REDACTED/data/objectnet/objectnet-1.0/',
                 transform=transforms.Compose([transforms.Resize(input_size), transforms.CenterCrop(input_size), transforms.ToTensor()]),
                 normalize=None, img_format='png'):
        self.root = root
        self.transform = transform
        self.normalize = normalize
        # sorted() makes index -> file reproducible; raw glob order depends on the filesystem.
        files = sorted(glob.glob(os.path.join(root, '**', '*.' + img_format), recursive=True))
        # Keyed by file basename: two files with the same basename in different
        # folders collapse into a single entry (the later one in sorted order wins).
        self.pathDict = {os.path.basename(f): f for f in files}
        self.imgs = list(self.pathDict.keys())
        self.loader = self.pil_loader
        # os.path.join works whether or not `root` ends with a slash.
        with open(os.path.join(self.root, 'mappings', 'folder_to_onet_id.json'), 'r') as f:
            self.folder_to_onet_id = json.load(f)

    def __getitem__(self, index):
        """
        Get a transformed image and its label.
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, onet_id). onet_id is the ObjectNet class id
            looked up in folder_to_onet_id.json.
        """
        img, onet_id = self.getImage(index)
        img = self.transform(img)
        if self.normalize is not None:
            img = self.normalize(img)

        return img, onet_id

    def getImage(self, index):
        """
        Load the raw (untransformed) image and its label.
        Args:
            index (int): Index
        Return:
            tuple: Tuple (image, onet_id). image is the RGB PIL image with a
            2-pixel border removed; onet_id is the ObjectNet class id of the
            image's parent folder.
        """
        filepath = self.pathDict[self.imgs[index]]
        img = self.loader(filepath)

        # crop out red border (2 pixels on every side)
        width, height = img.size
        cropArea = (2, 2, width - 2, height - 2)
        img = img.crop(cropArea)

        # map parent folder name to objectnet id
        folder = os.path.basename(os.path.dirname(filepath))
        onet_id = self.folder_to_onet_id[folder]
        return (img, onet_id)

    def __len__(self):
        """Get the number of ObjectNet images to load."""
        return len(self.imgs)

    def pil_loader(self, path):
        """Load the image at ``path`` as an RGB PIL image."""
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')



####### EVALUATION

def eval_on_objectnet(model, batch_size=64, num_workers=16,
                      normalize=transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                      root=None):
    """Compute the top-1 ObjectNet accuracy (in percent) of an ImageNet classifier.

    ImageNet predictions are mapped to ObjectNet ids with
    ``<root>/mappings/inet_id_to_onet_id.json``. A predicted ImageNet class
    with no entry in that mapping becomes -1 and is therefore counted as wrong.

    Args:
        model: classifier whose ``model(imgs).argmax(1)`` is taken as the
            predicted ImageNet class id.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        normalize: transform applied to each image tensor after resize/crop;
            pass None to feed unnormalized tensors.
        root: ObjectNet directory; None uses the ObjectNet class default.

    Returns:
        0-dim tensor equal to 100 * correct / total.

    Raises:
        RuntimeError: if the dataset yielded no images (e.g. wrong root).
    """
    if root is None:
        dset = ObjectNet(normalize=normalize)
    else:
        dset = ObjectNet(root=root, normalize=normalize)
    # Order does not affect accuracy; not shuffling keeps runs reproducible.
    loader = DataLoader(dset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Read the mapping from the same root the dataset uses instead of a separate hard-coded path.
    with open(os.path.join(dset.root, 'mappings', 'inet_id_to_onet_id.json'), 'r') as f:
        inet_id_to_onet_id = {int(k): v for k, v in json.load(f).items()}

    model = model.to(device).eval()

    correct, total = 0, 0
    for imgs, onet_ids in tqdm(loader):
        imgs = imgs.to(device)

        with torch.no_grad():
            inet_preds = model(imgs).argmax(1).cpu()

        onet_preds = torch.LongTensor([inet_id_to_onet_id.get(p, -1) for p in inet_preds.tolist()])

        correct += (onet_preds == onet_ids).sum()
        total += imgs.shape[0]

    if total == 0:
        raise RuntimeError(f'No ObjectNet images found under {dset.root!r}')
    return 100. * correct / total

# from torchvision.models import resnet50, resnet18
# net = resnet18(pretrained=True)
# print('Pretrained Resnet18 objectnet acc: ', eval_on_objectnet(net, num_workers=8))
# net = resnet50(pretrained=True)
# print('Pretrained Resnet50 objectnet acc: ', eval_on_objectnet(net, num_workers=8))
# # Output: Top 1 accuracy of 25.5287 for Resnet50, 17.4 for Resnet18

#### Robust Model Eval


# Adversarial-training epsilons to evaluate for each threat-model norm.
l2_epsilons = [0, 3]  # [0, 0.25, 0.5, 1, 3, 5]
linf_epsilons = []  # [0, 0.5, 1.0, 2.0, 4.0, 8.0]

root = '/REDACTED/dcr_models/pretrained-robust/'
results_path = 'results/objectnet_no_norm.pkl'

# Guarded so that importing this module -- including DataLoader worker
# processes started with the 'spawn' method, which re-import the main
# module -- does not re-run the whole evaluation.
if __name__ == '__main__':
    train_ds = ImageNet('/tmp')
    # for arch in ['wide_resnet50_2']:#'resnet18', 'resnet50']:
    for arch in ['shufflenet', 'mobilenet', 'vgg16_bn', 'resnext50_32x4d', 'densenet']:
        for adv_train_norm, epsilons in zip(['l2', 'linf'], [l2_epsilons, linf_epsilons]):
            for adv_train_eps in epsilons:
                # eps=0 models always use the l2 key / checkpoint name, whatever the norm.
                key_norm = 'l2' if adv_train_eps == 0 else adv_train_norm
                mkey = f'{arch}_{key_norm}_eps{adv_train_eps}'
                results = load_cached_results(results_path)
                if mkey not in results:
                    arch_arg, acf = get_arch(arch)
                    net, _ = robust_model_utils.make_and_restore_model(arch=arch_arg, dataset=train_ds,
                                    resume_path=f'{root}/{mkey}.ckpt', add_custom_forward=acf)
                    net = net.model.eval()
                    objectnet_acc = eval_on_objectnet(net, normalize=None)
                    results[mkey] = objectnet_acc
                    cache_results(results_path, results)

                print('Model: {:20s}, ObjectNet Accuracy: {:.3f}'.format(mkey, results[mkey]))
            
