import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import torch.utils.data as data
from PIL import Image
import json
from torch.utils.data import Dataset


class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images.

    Reads an attribute CSV whose first column (used as the index) holds image
    file names relative to ``img_dir`` and whose remaining columns hold
    per-image attribute values. Each item is returned as
    ``(img, label, sensitive, index)``.

    Args:
        csv_path: Path to the attribute CSV.
        img_dir: Directory containing the image files named in the CSV index.
        transform: Optional callable applied to the loaded PIL image.
        mode: ``'binary'`` keeps only ``binary_label`` as the target;
            ``'multi-label'`` keeps every CSV column as a target.
        binary_label: Target column used in ``'binary'`` mode.
        sensitive_label: Column whose values are returned as the sensitive
            attribute of every item.

    Attributes set by ``__init__``:
        img_names: Image file names (the CSV index values).
        targets: 2-D array of target values, one row per image.
        sensitives: ``(n, 1)`` array of ``sensitive_label`` values, or
            ``None`` in ``'multi-label'`` mode when that column is absent.
        imratio_list: For each target column, the fraction of 1s among its
            0/1 entries (other values are ignored).

    Raises:
        ValueError: If ``mode`` is not ``'binary'`` or ``'multi-label'``.
        KeyError: In ``'binary'`` mode, if ``binary_label`` or
            ``sensitive_label`` is not a CSV column.
    """

    def __init__(self, csv_path, img_dir, transform=None, mode='binary', binary_label='Big_Nose', sensitive_label='Male'):
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.transform = transform
        self.mode = mode
        self.binary_label = binary_label
        self.sensitive_label = sensitive_label

        if self.mode == 'multi-label':
            self._load_multi_label(pd.read_csv(csv_path, index_col=0))
        elif self.mode == 'binary':
            self._load_binary(pd.read_csv(csv_path, index_col=0))
        else:
            # A proper exception class is required: raising a bare string is
            # itself a TypeError in Python 3.
            raise ValueError(
                f"Error in CelebA mode: {mode!r} is not one of "
                "'binary' or 'multi-label'."
            )

    @staticmethod
    def _compute_imratio(values):
        """Return ``(imratio, count_dict)`` for one label's values.

        ``count_dict`` maps each distinct value to its count. ``imratio`` is
        ``#1 / (#1 + #0)``; values other than 0 and 1 are ignored, and 0.0 is
        returned when ``values`` contains neither 0 nor 1.
        """
        unique, counts = np.unique(values, return_counts=True)
        count_dict = dict(zip(unique, counts))
        one_count = count_dict.get(1, 0)
        zero_count = count_dict.get(0, 0)
        total = one_count + zero_count
        imratio = one_count / total if total else 0.0
        return imratio, count_dict

    def _load_multi_label(self, df):
        """Use every CSV column as a target and compute one imratio per column."""
        print('data dimensionality : ', df.shape)
        self.img_names = df.index.values
        self.targets = df.values
        # __getitem__ always returns a sensitive value, so keep the column
        # when the CSV provides it.
        if self.sensitive_label in df.columns:
            self.sensitives = df[self.sensitive_label].values.reshape(-1, 1)
        else:
            self.sensitives = None

        print('self.targets.shape : ', self.targets.shape)
        imratio_list = []
        for column in range(self.targets.shape[1]):
            # Slice one attribute across all images (a column), not one
            # image's attributes (a row).
            imratio, count_dict = self._compute_imratio(self.targets[:, column])
            print('count_dict : ', count_dict)
            imratio_list.append(imratio)
        self.imratio_list = imratio_list
        print('imratio_list : ', imratio_list)

    def _load_binary(self, df):
        """Keep only ``binary_label`` as the target and record the sensitive column."""
        from collections import Counter

        # Read the sensitive column before the frame is narrowed to the target.
        self.sensitives = df[self.sensitive_label].values.reshape(-1, 1)

        print('Original data dimensionality : ', df.shape)
        df = df[[self.binary_label]]
        print('Binary data dimensionality : ', df.shape)

        self.img_names = df.index.values
        self.targets = df.values
        imratio, _ = self._compute_imratio(self.targets)
        self.imratio_list = [imratio]
        print('imratio_list : ', self.imratio_list)

        print('labels : ', df.columns[0], Counter(self.targets.reshape(-1).tolist()))
        print('sensitive labels : ', self.sensitive_label, Counter(self.sensitives.reshape(-1).tolist()))

    def __getitem__(self, index):
        """Return ``(img, label, sensitive, index)`` for the item at ``index``.

        ``img`` is the image opened from ``img_dir`` (passed through
        ``transform`` when one was given), ``label`` is row ``index`` of
        ``targets`` and ``sensitive`` is row ``index`` of ``sensitives``.

        Raises:
            ValueError: If no sensitive column was loaded (``'multi-label'``
                mode with ``sensitive_label`` missing from the CSV).
        """
        # Check before opening the file so a failure does not leave it open.
        if self.sensitives is None:
            raise ValueError(
                f"Column {self.sensitive_label!r} not found in {self.csv_path!r}; "
                "cannot return a sensitive label."
            )

        img = Image.open(os.path.join(self.img_dir, self.img_names[index]))
        if self.transform is not None:
            img = self.transform(img)

        return img, self.targets[index], self.sensitives[index], index

    def __len__(self):
        """Return the number of images (rows of ``targets``)."""
        return self.targets.shape[0]
