import os
import torch
import torchvision.datasets as datasets
import numpy as np
from PIL import Image
import sys
sys.path.append('..')

# Per-dataset normalisation constants, each stored as a pair of 3-tuples.
# NOTE(review): presumably (per-channel mean, per-channel std) in RGB order for
# images scaled to [0, 1] -- confirm against the transforms that consume them.
CIFAR10_NORM = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
CIFAR100_NORM = ((0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
                 (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))

# Corruption names accepted by load_corruption(). Each name is also the stem of
# the file that is read for it ('<name>.npy' for CIFAR, '<name>.pth' for
# ImageNet), so entries must match the file names on disk.
CORRUPTION = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
              'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'gaussian_blur',
              'snow', 'frost', 'fog', 'brightness', 'contrast', 'elastic_transform',
              'pixelate', 'jpeg_compression', 'spatter', 'saturate']


def _tolist(x):
    """Return *x* as a plain Python list.

    Args:
        x: A ``list`` (returned unchanged, not copied), a ``tuple``, a
            ``numpy.ndarray`` or a ``torch.Tensor``.

    Returns:
        The list itself when *x* already is one, otherwise a new list holding
        the elements of *x* (``.tolist()`` for arrays and tensors).

    Raises:
        TypeError: If *x* is of any other type.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    if isinstance(x, (np.ndarray, torch.Tensor)):
        return x.tolist()
    raise TypeError(
        "expected list, tuple, numpy.ndarray or torch.Tensor, got {}".format(
            type(x).__name__))

def load_corruption(data_dir, c, l, dataset):
    """Load the images and labels of one corrupted benchmark split.

    Args:
        data_dir: Root directory that contains the ``CIFAR-10-C``,
            ``CIFAR-100-C`` or ``ImageNet-C`` sub-directory.
        c: Corruption name; must be an entry of ``CORRUPTION``.
        l: Severity level, 1 to 5 inclusive. Only applied to the CIFAR
            datasets.
        dataset: ``'cifar10'``, ``'cifar100'`` or ``'imagenet'``.

    Returns:
        A ``(data, label)`` pair of numpy arrays. For the CIFAR datasets this
        is the 10000-sample slice belonging to severity ``l``; for ImageNet it
        is the full content of the corruption and label files.

    Raises:
        RuntimeError: If the corruption name is unknown, the dataset is not
            supported, or (CIFAR only) the severity level is out of range.
    """
    if c not in CORRUPTION:
        raise RuntimeError("unknown corruption type")

    sub_dirs = {
        'cifar10': 'CIFAR-10-C',
        'cifar100': 'CIFAR-100-C',
        'imagenet': 'ImageNet-C',
    }
    if dataset not in sub_dirs:
        # Previously this surfaced as an UnboundLocalError on ``loc_dir``.
        raise RuntimeError("unknown dataset: {!r}".format(dataset))
    loc_dir = os.path.join(data_dir, sub_dirs[dataset])

    if dataset == 'imagenet':
        # NOTE(review): the severity level is not applied on this branch;
        # presumably each .pth file already holds a single severity -- confirm.
        data = torch.load(os.path.join(loc_dir, c + '.pth')).numpy()
        label = torch.load(os.path.join(loc_dir, 'labels.pth')).numpy()
        return data, label

    # Validate before touching the disk so a bad level fails fast.
    if l > 5 or l <= 0:
        raise RuntimeError("severity level out of range")

    # Each .npy file stacks the five severities (50000 images in total):
    # rows 0-9999 are severity 1, rows 10000-19999 are severity 2, and so on.
    start, stop = (l - 1) * 10000, l * 10000
    data = np.load(os.path.join(loc_dir, c + '.npy'))[start:stop]
    label = np.load(os.path.join(loc_dir, 'labels.npy'))[start:stop]
    return data, label


class SubCifar10C(datasets.CIFAR10):
    """CIFAR-10-C samples exposed through the ``datasets.CIFAR10`` interface.

    The parent class is initialised on ``info_path`` with ``train=False`` and
    its ``data`` / ``targets`` are then replaced by the corrupted images and
    labels returned by ``load_corruption``.
    NOTE(review): the parent constructor presumably needs the original
    CIFAR-10 files under ``info_path`` as well -- confirm.

    Args:
        info_path: Dataset root. Passed to the parent class and to
            ``load_corruption``, which reads from ``CIFAR-10-C`` below it.
        corruption: Corruption name; must be an entry of ``CORRUPTION``.
        level: Severity level, 1 to 5 inclusive.
        transform: Optional callable applied to each PIL image. When ``None``
            the PIL image is returned untransformed.
        indices: Optional indices selecting a subset of the loaded samples.
        mode: One of ``'warmup'``, ``'eval_train'``, ``'label'``,
            ``'unlabel'`` or ``'test'``; decides what ``__getitem__`` returns.
        probs: Per-sample values returned alongside each sample in
            ``'label'`` mode (GMM probs).
        psl: Optional replacement labels (list, tuple, ndarray or tensor).
            When given they override the loaded targets.
        return_idx: In ``'warmup'`` / ``'eval_train'`` / ``'test'`` mode, also
            return the sample index.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
        RuntimeError: From ``load_corruption`` for an unknown corruption or an
            out-of-range level.
    """

    _MODES = ('warmup', 'eval_train', 'label', 'unlabel', 'test')

    def __init__(self, info_path, corruption, level, transform=None, indices=None,
                 mode=None, probs=None, psl=None, return_idx=False):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, mode))
        super(SubCifar10C, self).__init__(info_path, train=False, transform=transform)

        # load_corruption() validates the corruption name itself; calling it
        # unconditionally yields a clear RuntimeError instead of an
        # UnboundLocalError for an unknown name.
        self.data, self.targets = load_corruption(
            info_path, corruption, level, 'cifar10')

        if indices is not None:
            self.data = self.data[indices]
            self.targets = list(self.targets[indices])

        if psl is not None:
            self.targets = _tolist(psl)
        self.mode = mode
        self.probs = probs  # GMM probs;
        self.return_idx = return_idx

    def _fetch_image(self, index):
        """Return sample ``index`` as a PIL image, transformed when a transform is set."""
        img = Image.fromarray(self.data[index])
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        """Return the sample at ``index`` in the layout selected by ``self.mode``.

        Returns:
            * ``'warmup'`` / ``'eval_train'`` / ``'test'``: ``(img, target)``,
              or ``(img, target, index)`` when ``return_idx`` is set.
            * ``'label'``: ``(img, target, prob)``.
            * ``'unlabel'``: ``img`` only.

        Raises:
            ValueError: If ``self.mode`` is none of the above.
        """
        mode = self.mode
        if mode in ('warmup', 'eval_train', 'test'):
            target = self.targets[index]
            img = self._fetch_image(index)
            if self.return_idx:
                return img, target, index
            return img, target
        if mode == 'label':
            target, prob = self.targets[index], self.probs[index]
            return self._fetch_image(index), target, prob
        if mode == 'unlabel':
            return self._fetch_image(index)
        raise ValueError("unsupported mode: {!r}".format(mode))


class SubCifar100C(datasets.CIFAR100):
    """CIFAR-100-C samples exposed through the ``datasets.CIFAR100`` interface.

    The parent class is initialised on ``info_path`` with ``train=False`` and
    its ``data`` / ``targets`` are then replaced by the corrupted images and
    labels returned by ``load_corruption``.
    NOTE(review): the parent constructor presumably needs the original
    CIFAR-100 files under ``info_path`` as well -- confirm.

    Args:
        info_path: Dataset root. Passed to the parent class and to
            ``load_corruption``, which reads from ``CIFAR-100-C`` below it.
        corruption: Corruption name; must be an entry of ``CORRUPTION``.
        level: Severity level, 1 to 5 inclusive.
        transform: Optional callable applied to each PIL image. When ``None``
            the PIL image is returned untransformed.
        indices: Optional indices selecting a subset of the loaded samples.
        mode: One of ``'warmup'``, ``'eval_train'``, ``'label'``,
            ``'unlabel'`` or ``'test'``; decides what ``__getitem__`` returns.
        probs: Per-sample values returned alongside each sample in
            ``'label'`` mode (GMM probs).
        psl: Optional replacement labels (list, tuple, ndarray or tensor).
            When given they override the loaded targets.
        return_idx: In ``'warmup'`` / ``'eval_train'`` / ``'test'`` mode, also
            return the sample index.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
        RuntimeError: From ``load_corruption`` for an unknown corruption or an
            out-of-range level.
    """

    _MODES = ('warmup', 'eval_train', 'label', 'unlabel', 'test')

    def __init__(self, info_path, corruption, level, transform=None, indices=None,
                 mode=None, probs=None, psl=None, return_idx=False):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, mode))
        super(SubCifar100C, self).__init__(info_path, train=False, transform=transform)

        # load_corruption() validates the corruption name itself; calling it
        # unconditionally yields a clear RuntimeError instead of an
        # UnboundLocalError for an unknown name.
        self.data, self.targets = load_corruption(
            info_path, corruption, level, 'cifar100')

        if indices is not None:
            self.data = self.data[indices]
            self.targets = list(self.targets[indices])

        if psl is not None:
            self.targets = _tolist(psl)
        self.mode = mode
        self.probs = probs  # GMM probs;
        self.return_idx = return_idx

    def _fetch_image(self, index):
        """Return sample ``index`` as a PIL image, transformed when a transform is set."""
        img = Image.fromarray(self.data[index])
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        """Return the sample at ``index`` in the layout selected by ``self.mode``.

        Returns:
            * ``'warmup'`` / ``'eval_train'`` / ``'test'``: ``(img, target)``,
              or ``(img, target, index)`` when ``return_idx`` is set.
            * ``'label'``: ``(img, target, prob)``.
            * ``'unlabel'``: ``img`` only.

        Raises:
            ValueError: If ``self.mode`` is none of the above.
        """
        mode = self.mode
        if mode in ('warmup', 'eval_train', 'test'):
            target = self.targets[index]
            img = self._fetch_image(index)
            if self.return_idx:
                return img, target, index
            return img, target
        if mode == 'label':
            target, prob = self.targets[index], self.probs[index]
            return self._fetch_image(index), target, prob
        if mode == 'unlabel':
            return self._fetch_image(index)
        raise ValueError("unsupported mode: {!r}".format(mode))
