import pdb
from collections import defaultdict
import json
import os
import pickle
import zipfile

import numpy as np
from PIL import Image, ImageFile

import torch
from torchvision import datasets as t_datasets
from tqdm import tqdm

import random
import glob
from dataset.dataset_util import get_paths

def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    :param path: filesystem path of the image file to read.
    :return: the result of ``Image.open(...).convert('RGB')``.
    """
    # The file is opened explicitly (instead of handing the path to PIL) so
    # the handle is closed deterministically and no ResourceWarning is raised
    # (https://github.com/python-pillow/Pillow/issues/835). The conversion
    # happens while the handle is still open.
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


class ImageDataset(torch.utils.data.Dataset):
    """Dataset of images grouped in one sub-directory per class.

    Every sample is loaded with ``pil_loader`` and optionally passed through
    ``transform``. Samples are ``(img, label)`` tuples, or
    ``(img, label, path)`` tuples when ``ret_path`` is set.
    """

    def __init__(self, root, class_to_idx, transform=None, ret_path=False):
        """
        :param root: Dataset root. Should follow the structure class1/0.jpg...n.jpg, class2/0.jpg...n.jpg
        :param class_to_idx: dictionary mapping the classnames to integers. Only the classes
            listed here are scanned, in the iteration order of the mapping.
        :param transform: optional callable applied to each loaded image; skipped when None.
        :param ret_path: boolean indicating whether to return the image path or not
            (useful for KNN for plotting nearest neighbors)
        """
        self.transform = transform
        self.label_to_idx = class_to_idx
        self.ret_path = ret_path

        # Parallel lists: self.labels[i] is the integer label of self.paths[i].
        self.paths = []
        self.labels = []
        for class_name in class_to_idx:
            class_dir = os.path.join(root, class_name)
            # NOTE(review): get_paths is a project helper; presumably it
            # returns the image file paths found under class_dir -- confirm.
            found = get_paths(class_dir)
            self.paths.extend(found)
            # One label entry per path found for this class.
            self.labels.extend(self.label_to_idx[class_name] for _ in found)

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        :param idx: index into the collected paths/labels.
        :return: ``(img, label)``, or ``(img, label, im_path)`` when ``ret_path`` is truthy.
        """
        im_path = self.paths[idx]
        label = self.labels[idx]

        img = pil_loader(im_path)
        if self.transform is not None:
            img = self.transform(img)

        if self.ret_path:
            return img, label, im_path
        return img, label


