import torch.utils.data as data
from PIL import Image
import os
import json
import numpy as np
import collections
import csv
import torch
from util.inat_classes import inat_category
from collections import Counter
import random
def default_loader(path):
    """Open the image file at *path* with PIL and return it in RGB mode."""
    image = Image.open(path)
    # Normalise grayscale / RGBA / palette images to 3-channel RGB.
    return image.convert('RGB')
class IGNAT_Loader(data.Dataset):
    """iNat dataset restricted to ``inat_category`` with a fixed per-class split.

    The annotation JSON must contain ``images`` (entries with ``file_name``
    and ``id``) and ``categories`` (entries with ``id`` and ``name``).
    ``annotations`` (entries with ``category_id``) is optional; when it is
    absent every image is assigned category ``0``.

    Processing steps:

    1. keep only images whose category id is in ``inat_category``;
    2. walk the kept images in annotation order and send the first
       ``num_test_per_class`` images of each category to the test split and
       the remaining ones to the train split;
    3. remap category ids to contiguous targets ``0..K-1``, ordered by
       descending image count over the kept images (ties keep first-seen
       order).

    Args:
        root: String prepended verbatim to every ``file_name``. No separator
            is inserted, so it presumably needs a trailing ``/``.
        ann_file: Path to the annotation JSON file.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the integer target.
        loader: Callable that loads an image from a path.
        is_train: Expose the train split if True, otherwise the test split.
        num_test_per_class: Number of leading images of each category that
            are reserved for the test split (default 30, the value that was
            previously hard-coded).

    Attributes:
        ids, imgs, classes: Parallel lists of image ids, file names and
            remapped integer targets for the selected split.
        idx_to_class: Category id -> name, keyed by the ORIGINAL category
            ids (not by the remapped targets).
        class_pre2now: Original category id -> remapped target. Only
            categories that actually have images are present.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None,
                 loader=default_loader, is_train=True, num_test_per_class=30):
        print('Loading annotations from: ' + os.path.basename(ann_file))
        with open(ann_file) as data_file:
            ann_data = json.load(data_file)
        imgs = [aa['file_name'] for aa in ann_data['images']]
        im_ids = [aa['id'] for aa in ann_data['images']]
        if 'annotations' in ann_data:
            # Assumes annotations[i] describes images[i] (same order).
            classes = [aa['category_id'] for aa in ann_data['annotations']]
        else:
            classes = [0] * len(im_ids)
        idx_to_class = {cc['id']: cc['name'] for cc in ann_data['categories']}
        print('\t' + str(len(imgs)) + ' images')
        print('\t' + str(len(idx_to_class)) + ' classes')

        self.root = root
        self.idx_to_class = idx_to_class
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        # Step 1: indices of images whose category belongs to inat_category.
        allowed = set(inat_category)
        kept = [i for i in range(len(im_ids)) if classes[i] in allowed]

        # (category, count) pairs, most frequent first; most_common() sorts
        # stably, so ties keep first-seen order.
        class_sorted = Counter(classes[i] for i in kept).most_common()
        if class_sorted:
            print(class_sorted[0])   # most frequent (category, count)
            print(class_sorted[-1])  # least frequent (category, count)
        print(len(kept))

        # Step 2: deterministic per-class train/test split.
        train_idx, test_idx = self._split_per_class(
            kept, classes, num_test_per_class)
        selected = train_idx if is_train else test_idx

        # Step 3: contiguous targets by descending frequency. The mapping is
        # built only from categories that actually occur, so a category that
        # is listed in inat_category but has no images no longer triggers an
        # IndexError.
        self.class_pre2now = {
            cls: target for target, (cls, _) in enumerate(class_sorted)
        }

        self.ids = [im_ids[i] for i in selected]
        self.imgs = [imgs[i] for i in selected]
        self.classes = [self.class_pre2now[classes[i]] for i in selected]

    @staticmethod
    def _split_per_class(indices, classes, num_test_per_class):
        """Split *indices* into ``(train, test)`` lists of indices.

        Walking *indices* in order, the first *num_test_per_class* indices
        of each class (looked up via ``classes[i]``) go to ``test`` and all
        later ones go to ``train``.
        """
        seen = Counter()
        train, test = [], []
        for i in indices:
            cls = classes[i]
            (test if seen[cls] < num_test_per_class else train).append(i)
            seen[cls] += 1
        return train, test

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at *index*."""
        # root is used as a raw prefix (no separator inserted).
        path = self.root + self.imgs[index]
        target = self.classes[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.imgs)
def get_parser(dataset_name):
    """Return the CSV-row parser registered for *dataset_name*.

    Each parser takes a parsed CSV row (a sequence of fields) and an
    ``is_train`` flag:

    * train rows yield ``(user_id, image_id, class_id)``;
    * test rows yield ``(image_id, class_id)``.

    Rows must have exactly the expected number of fields, otherwise tuple
    unpacking raises ``ValueError``. An unknown *dataset_name* raises
    ``KeyError``.
    """

    def inat_parser(line, is_train=True):
        # iNat rows carry one extra trailing column that is discarded.
        if not is_train:
            image_id, class_id, _ = line
            return image_id, class_id
        user_id, image_id, class_id, _ = line
        return user_id, image_id, class_id

    def landmarks_parser(line, is_train=True):
        if not is_train:
            image_id, class_id = line
            return image_id, class_id
        user_id, image_id, class_id = line
        return user_id, image_id, class_id

    # CIFAR CSVs share the landmarks row layout.
    return {
        'inat': inat_parser,
        'landmarks': landmarks_parser,
        'cifar': landmarks_parser,
    }[dataset_name]
class IGNAT_Loader_User_120k_Train(data.Dataset):
    """Training split of the per-user iNat CSV dataset.

    *ann_file* is a CSV whose first row is a header; every following row is
    parsed with the ``"inat"`` parser as ``(user_id, image_id, class_id,
    <ignored>)``. Images are read from ``<root>/<image_id>.jpg`` and targets
    are 0-d integer ``torch`` tensors holding ``int(class_id)``.

    Args:
        root: Directory containing the ``.jpg`` files.
        ann_file: Path to the CSV annotation file.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the target tensor.
        loader: Callable that loads an image from a path.

    Raises:
        FileNotFoundError: if *ann_file* does not exist.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None,
                 loader=default_loader):
        print('Train file: %s' % ann_file)
        if not os.path.exists(ann_file):
            print('Error: file does not exist.')
            # Returning here used to leave a half-initialised object whose
            # first use failed with AttributeError; fail fast instead.
            raise FileNotFoundError(ann_file)
        parser = get_parser("inat")
        img_names = []
        labels = []
        # newline='' is what the csv module expects for file objects.
        with open(ann_file, newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header; tolerate an empty file
            for line in reader:
                _user_id, image_id, class_id = parser(line, is_train=True)
                # Previously the image name was never recorded (the label was
                # appended twice instead), leaving the dataset empty.
                img_names.append(image_id)
                labels.append(torch.tensor(int(class_id)))
        self.root = root
        self.imgs = img_names
        self.classes = labels
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at *index*."""
        path = self.root + "/" + self.imgs[index] + ".jpg"
        target = self.classes[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of rows (images) read from the CSV."""
        return len(self.imgs)
class IGNAT_Loader_User_120k_Test(data.Dataset):
    """Test split of the per-user iNat CSV dataset.

    *ann_file* is a CSV whose first row is a header; every following row is
    parsed with the ``"inat"`` test parser as ``(image_id, class_id,
    <ignored>)``. Images are read from ``<root>/<image_id>.jpg`` and targets
    are 0-d integer ``torch`` tensors holding ``int(class_id)``.

    Args:
        root: Directory containing the ``.jpg`` files.
        ann_file: Path to the CSV annotation file.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the target tensor.
        loader: Callable that loads an image from a path.

    Raises:
        FileNotFoundError: if *ann_file* does not exist.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None,
                 loader=default_loader):
        # This is the test loader; the message previously said 'Train file'.
        print('Test file: %s' % ann_file)
        if not os.path.exists(ann_file):
            print('Error: file does not exist.')
            # Returning here used to leave a half-initialised object whose
            # first use failed with AttributeError; fail fast instead.
            raise FileNotFoundError(ann_file)
        parser = get_parser("inat")
        img_names = []
        labels = []
        # newline='' is what the csv module expects for file objects.
        with open(ann_file, newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header; tolerate an empty file
            for line in reader:
                image_id, class_id = parser(line, is_train=False)
                img_names.append(image_id)
                labels.append(torch.tensor(int(class_id)))
        self.root = root
        self.imgs = img_names
        self.classes = labels
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at *index*."""
        path = self.root + "/" + self.imgs[index] + ".jpg"
        target = self.classes[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of rows (images) read from the CSV."""
        return len(self.imgs)