# -*- coding: utf-8 -*-

import os
from PIL import Image
import os.path
import time
import torch
import torchvision.datasets as dset
import torchvision.transforms as trn
import torch.utils.data as data
import numpy as np

from PIL import Image
from .utils import noisify_y
# /////////////// Data Loader ///////////////


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Return True when ``filename`` carries a recognised image extension.

    The match is case-insensitive, so ``photo.JPG`` is accepted as ``.jpg``.

    Args:
        filename (string): file name or path to test.
    Returns:
        bool: whether the name ends with one of ``IMG_EXTENSIONS``.
    """
    # str.endswith accepts a tuple of suffixes and tests them all in one call.
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))


def find_classes(dir):
    """List the class sub-directories of ``dir`` and number them.

    Args:
        dir (string): dataset root whose immediate sub-directories are classes.
    Returns:
        tuple: ``(classes, class_to_idx)``. ``classes`` is the sorted list of
        sub-directory names; ``class_to_idx`` maps each name to its position
        in that list.
    """
    subdirs = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    index_of = {name: position for position, name in enumerate(subdirs)}
    return subdirs, index_of


def make_dataset(dir, class_to_idx, num_classes=1000):
    """Collect ``(path, label)`` pairs for every image under ``dir``.

    Each immediate sub-directory of ``dir`` is one class, labelled through
    ``class_to_idx``; classes whose label is ``>= num_classes`` are skipped.
    Non-directory entries directly under ``dir`` are ignored. Images are
    gathered recursively and in sorted order, so the result is deterministic.

    Args:
        dir (string): dataset root (``~`` is expanded).
        class_to_idx (dict): class directory name -> integer label.
        num_classes (int): keep only classes with label ``< num_classes``.
    Returns:
        list: ``(path, label)`` tuples.
    Raises:
        KeyError: if ``dir`` has a sub-directory missing from ``class_to_idx``.
    """
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        # Test for a directory *before* looking up the label: loose files at
        # the top level are never in class_to_idx (find_classes only keeps
        # directories), so looking them up first raised KeyError.
        if not os.path.isdir(d):
            continue
        label = class_to_idx[target]
        if label >= num_classes:
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    images.append((os.path.join(root, fname), label))

    return images


def pil_loader(path):
    """Load the image at ``path`` with Pillow and return it in RGB mode.

    The file is opened explicitly inside a ``with`` block so the handle is
    closed deterministically (avoids the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835). ``convert`` runs
    before the block exits, so the pixels are read while the file is open.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def default_loader(path):
    """Image loader used by the datasets in this module.

    Delegates to ``pil_loader``, i.e. returns the image at ``path`` as an
    RGB PIL image.
    """
    return pil_loader(path)


class ImagenetNoise(data.Dataset):
    """ImageNet with optional input (x) noise and label (y) noise.

    Training mode: each sample is, with probability ``xnoise_rate``, read
    from a pre-generated corrupted copy under
    ``/data/leyang/ImageNet_C/<noise_name>`` instead of the clean image, and
    labels may be corrupted with ``noisify_y``. Validation mode always serves
    the clean images.

    ``dataset[i]`` returns ``(i, (img, xnoisy), (noisy_label, true_label))``
    in training mode and ``(i, img, true_label)`` otherwise.
    ``target_transform`` is stored but not applied by ``__getitem__``.
    """

    def __init__(self,
        train=True,
        xnoise_rate=0, xnoise_type='contrast', xnoise_arg=5,
        ynoise_type='symmetric', ynoise_rate=0,
        transform=None,
        target_transform=None,
        loader=default_loader,
        num_classes=1000, random_state=0):
        """Index the images and draw the x-/y-noise assignments.

        Args:
            train (bool): ImageNet train split (x-noise allowed) if True,
                otherwise the clean validation split.
            xnoise_rate (float): probability that a training sample is served
                from the corrupted copy.
            xnoise_type (str): corruption name; a name containing ``'mix'``
                is used as the directory name as-is.
            xnoise_arg: severity; the directory is ``f'{xnoise_type}_{xnoise_arg}'``
                for non-mix types.
            ynoise_type (str): label-noise scheme forwarded to ``noisify_y``.
            ynoise_rate (float): label-noise rate; 0 disables label noise.
            transform (callable, optional): applied to each loaded image.
            target_transform (callable, optional): stored, currently unused.
            loader (callable): ``path -> image``.
            num_classes (int): keep only classes with label ``< num_classes``.
            random_state (int): seed for the x-noise draw and ``noisify_y``.
        Raises:
            FileNotFoundError: x-noise requested but its directory is missing.
        """
        noise_imgs = None
        if train:
            clean_root = '/data/LargeData/Large/ImageNet/train'
            classes, class_to_idx = find_classes(clean_root)
            imgs = make_dataset(clean_root, class_to_idx, num_classes)
            if xnoise_rate > 0:
                if 'mix' in xnoise_type:
                    noise_name = xnoise_type
                else:
                    noise_name = f'{xnoise_type}_{xnoise_arg}'
                noise_root = '/data/leyang/ImageNet_C/' + noise_name
                if not os.path.exists(noise_root):
                    # Raise rather than exit(1): a dataset constructor should
                    # not terminate the interpreter, and callers can catch this.
                    raise FileNotFoundError(
                        'Noisy image not found! Please run make_imagenet_c.py '
                        f'first. Missing directory: {noise_root}')
                # NOTE(review): __getitem__ indexes noise_imgs with the clean
                # image's index, so this listing is assumed to mirror `imgs`
                # entry-for-entry — confirm make_imagenet_c.py guarantees it.
                noise_imgs = make_dataset(noise_root, class_to_idx, num_classes)
        else:
            clean_root = '/data/LargeData/Large/ImageNet/val'
            classes, class_to_idx = find_classes(clean_root)
            imgs = make_dataset(clean_root, class_to_idx, num_classes)

        self.num_classes = num_classes
        self.train = train
        self.imgs = imgs
        self.noise_imgs = noise_imgs
        self.targets = np.asarray([img[1] for img in imgs])
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.xnoise_rate = xnoise_rate
        if noise_imgs is None:
            # No corrupted copies were loaded (validation split, or
            # xnoise_rate == 0), so no sample may be routed to them.
            self.xnoisy_or_not = np.zeros(len(imgs), dtype=bool)
        else:
            seed = np.random.RandomState(random_state)
            # The mask length is pinned to the ImageNet-1k train size
            # (presumably so a given random_state flags the same indices
            # whatever num_classes is); never draw fewer than len(imgs).
            selected = seed.binomial(1, xnoise_rate, size=(max(1281167, len(imgs)),))
            self.xnoisy_or_not = (selected == 1)[:len(imgs)]

        if ynoise_rate > 0.0:
            # noisify_y is given the labels as an (N, 1) column.
            label_column = np.asarray([[int(t)] for t in self.targets])
            noisy_column = noisify_y(train_labels=label_column, noise_type=ynoise_type,
                                     noise_rate=ynoise_rate, random_state=random_state,
                                     nb_classes=self.num_classes)
            self.noise_targets = np.asarray([row[0] for row in noisy_column])
            self.targets = np.asarray([row[0] for row in label_column])
            self.ynoisy_or_not = (self.targets != self.noise_targets)
        else:
            self.noise_targets = self.targets
            # Boolean mask (np.int was removed in NumPy 1.24); matches the
            # dtype of the branch above so `~` in get_noise is a logical NOT.
            self.ynoisy_or_not = np.zeros(len(imgs), dtype=bool)

    def get_noise(self):
        """Partition samples by the kind of noise they carry.

        Returns:
            dict: boolean masks over the dataset: ``'xy_noise'`` (both),
            ``'x_noise'`` (input only), ``'y_noise'`` (label only),
            ``'clean'`` (neither), plus the raw ``'xnoisy'`` / ``'ynoisy'``.
        """
        xy_noise = np.logical_and(self.xnoisy_or_not, self.ynoisy_or_not)
        x_noise = np.logical_and(self.xnoisy_or_not, ~self.ynoisy_or_not)
        y_noise = np.logical_and(~self.xnoisy_or_not, self.ynoisy_or_not)
        clean = np.logical_and(~self.xnoisy_or_not, ~self.ynoisy_or_not)
        return {
            'xy_noise': xy_noise,
            'x_noise': x_noise,
            'y_noise': y_noise,
            'clean': clean,
            'xnoisy': self.xnoisy_or_not,
            'ynoisy': self.ynoisy_or_not
        }

    def report_noise(self):
        """Print the number of samples in each ``get_noise`` category."""
        noise_stat = self.get_noise()
        print('Noise Stat:')
        for key, val in noise_stat.items():
            print(key, val.sum())

    def __getitem__(self, index):
        """Load sample ``index`` (clean or corrupted copy per ``xnoisy_or_not``).

        Returns:
            train: ``(index, (img, xnoisy), (noisy_label, true_label))``
            eval:  ``(index, img, true_label)``
        """
        xnoisy = self.xnoisy_or_not[index]
        noise_tar = self.noise_targets[index]
        source = self.noise_imgs if xnoisy else self.imgs
        path, true_tar = source[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): target_transform is intentionally not applied here.
        if self.train:
            return (index, (img, xnoisy), (noise_tar, true_tar))
        return (index, img, true_tar)

    def __len__(self):
        return len(self.imgs)

    def filenames(self, indices=None, basename=False):
        """Return clean image paths, optionally only for ``indices``.

        Args:
            indices (iterable of int, optional): positions to return; ``None``
                or an empty iterable returns every path. Numpy index arrays
                are accepted.
            basename (bool): return bare file names instead of full paths.
        Returns:
            list: paths (or basenames) in the requested order.
        """
        selected = [] if indices is None else list(indices)
        if selected:
            paths = [self.imgs[i][0] for i in selected]
        else:
            paths = [path for path, _ in self.imgs]
        if basename:
            paths = [os.path.basename(p) for p in paths]
        return paths

class ImagenetNoiseMulti(data.Dataset):
    """ImageNet mixing several input-corruption types, plus label noise.

    In training mode every sample independently gets one source, drawn with a
    multinomial: the clean image (probability ``1 - sum(xnoise_rates)``) or
    corruption ``xnoise_types[k]`` (probability ``xnoise_rates[k]``), read
    from ``/data/leyang/ImageNet_C/<type>_<arg>``. Validation mode always
    serves the clean images.

    ``dataset[i]`` returns ``(i, (img, xnoisy), (noisy_label, true_label))``
    in training mode and ``(img, true_label)`` otherwise.
    """

    def __init__(self,
        train=True,
        xnoise_rates=None, xnoise_types=None, xnoise_args=None,
        ynoise_type='symmetric', ynoise_rate=0,
        transform=None,
        target_transform=None,
        loader=default_loader,
        num_classes=1000, random_state=0):
        """Index every image source and draw the per-sample assignments.

        Args:
            train (bool): mixed-corruption train split if True, otherwise the
                clean validation split.
            xnoise_rates (list of float, optional): sampling probability of
                each corruption type; the clean image gets the remainder.
            xnoise_types (list of str, optional): corruption names, parallel
                to ``xnoise_rates``.
            xnoise_args (list, optional): severities, parallel to
                ``xnoise_types``; the directory is ``f'{type}_{arg}'``.
            ynoise_type (str): label-noise scheme forwarded to ``noisify_y``.
            ynoise_rate (float): label-noise rate; 0 disables label noise.
            transform (callable, optional): applied to each loaded image.
            target_transform (callable, optional): stored, currently unused.
            loader (callable): ``path -> image``.
            num_classes (int): keep only classes with label ``< num_classes``.
            random_state (int): seed for the source draw and ``noisify_y``.
        Raises:
            ValueError: rates/types differ in length, a rate is negative, or
                the rates sum to more than 1.
            FileNotFoundError: a corruption with positive rate has no directory.
        """
        # Copy into fresh lists (no shared mutable defaults; tuples accepted).
        xnoise_rates = [] if xnoise_rates is None else list(xnoise_rates)
        xnoise_types = [] if xnoise_types is None else list(xnoise_types)
        xnoise_args = [] if xnoise_args is None else list(xnoise_args)
        if len(xnoise_rates) != len(xnoise_types):
            raise ValueError('xnoise_rates and xnoise_types must have the same length')
        clean_rate = 1 - sum(xnoise_rates)
        if any(r < 0 for r in xnoise_rates) or clean_rate < -1e-12:
            raise ValueError('xnoise_rates must be non-negative and sum to at most 1')
        # Absorb float round-off (rates adding up to 1.0000000000000002) so
        # multinomial does not reject a tiny negative clean probability.
        clean_rate = max(clean_rate, 0)

        self.types = ['clean'] + xnoise_types
        self.rates = [clean_rate] + xnoise_rates
        # One entry per element of self.types so the drawn source index can
        # address self.imgs directly; None marks a source that is not loaded.
        self.imgs = []
        if train:
            for idx, xnoise_type in enumerate(self.types):
                if xnoise_type == 'clean':
                    clean_root = '/data/cuipeng/ImageNet_Clean/train'
                    classes, class_to_idx = find_classes(clean_root)
                    self.imgs.append(make_dataset(clean_root, class_to_idx, num_classes))
                elif self.rates[idx] > 0:
                    noise_name = f'{xnoise_type}_{xnoise_args[idx-1]}'
                    noise_root = '/data/leyang/ImageNet_C/' + noise_name
                    if not os.path.exists(noise_root):
                        raise FileNotFoundError(
                            'Noisy image not found! Please run make_imagenet_c.py '
                            f'first. Missing directory: {noise_root}')
                    # NOTE(review): indexed with the clean image's index in
                    # __getitem__, so assumed to mirror the clean listing.
                    self.imgs.append(make_dataset(noise_root, class_to_idx, num_classes))
                else:
                    # Zero-rate source: keep a placeholder so later indices
                    # stay aligned with self.types / self.rates.
                    self.imgs.append(None)
        else:
            clean_root = '/data/LargeData/Large/ImageNet/val'
            classes, class_to_idx = find_classes(clean_root)
            self.imgs.append(make_dataset(clean_root, class_to_idx, num_classes))

        n_samples = len(self.imgs[0])
        self.num_classes = num_classes
        self.train = train
        self.targets = np.asarray([img[1] for img in self.imgs[0]])
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        if train:
            seed = np.random.RandomState(random_state)
            one_hot = seed.multinomial(1, self.rates, size=(n_samples,))
            self.noise_select = np.argmax(one_hot, axis=1)
        else:
            # Only the clean validation list is loaded, so every sample must
            # come from source 0.
            self.noise_select = np.zeros(n_samples, dtype=int)
        self.xnoisy_or_not = (self.noise_select > 0)

        if ynoise_rate > 0.0:
            # noisify_y is given the labels as an (N, 1) column.
            label_column = np.asarray([[int(t)] for t in self.targets])
            noisy_column = noisify_y(train_labels=label_column, noise_type=ynoise_type,
                                     noise_rate=ynoise_rate, random_state=random_state,
                                     nb_classes=self.num_classes)
            self.noise_targets = np.asarray([row[0] for row in noisy_column])
            self.targets = np.asarray([row[0] for row in label_column])
            self.ynoisy_or_not = (self.targets != self.noise_targets)
        else:
            self.noise_targets = self.targets
            # Boolean mask (np.int was removed in NumPy 1.24).
            self.ynoisy_or_not = np.zeros(n_samples, dtype=bool)

    def get_noise(self):
        """Partition samples by the kind of noise they carry.

        Returns:
            dict: boolean masks over the dataset: ``'xy_noise'`` (both),
            ``'x_noise'`` (input only), ``'y_noise'`` (label only),
            ``'clean'`` (neither), plus the raw ``'xnoisy'`` / ``'ynoisy'``.
        """
        xy_noise = np.logical_and(self.xnoisy_or_not, self.ynoisy_or_not)
        x_noise = np.logical_and(self.xnoisy_or_not, ~self.ynoisy_or_not)
        y_noise = np.logical_and(~self.xnoisy_or_not, self.ynoisy_or_not)
        clean = np.logical_and(~self.xnoisy_or_not, ~self.ynoisy_or_not)
        return {
            'xy_noise': xy_noise,
            'x_noise': x_noise,
            'y_noise': y_noise,
            'clean': clean,
            'xnoisy': self.xnoisy_or_not,
            'ynoisy': self.ynoisy_or_not
        }

    def report_noise(self):
        """Print per-category noise counts and the configured source mix."""
        noise_stat = self.get_noise()
        print('Noise Stat:')
        for key, val in noise_stat.items():
            print(key, val.sum())
        print('Xnoise pattern:')
        print(self.types)
        print(self.rates)

    def __getitem__(self, index):
        """Load sample ``index`` from the source chosen in ``noise_select``.

        Returns:
            train: ``(index, (img, xnoisy), (noisy_label, true_label))``
            eval:  ``(img, true_label)``
        """
        noise_idx = self.noise_select[index]
        noise_tar = self.noise_targets[index]
        path, true_tar = self.imgs[noise_idx][index]
        img = self.loader(path)
        xnoisy = self.xnoisy_or_not[index]
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): target_transform is intentionally not applied here.
        if self.train:
            return (index, (img, xnoisy), (noise_tar, true_tar))
        # NOTE(review): unlike ImagenetNoise, the eval tuple omits `index`;
        # kept as-is because callers may unpack exactly two values.
        return img, true_tar

    def __len__(self):
        return len(self.imgs[0])
