from torchvision.datasets import MNIST
from src.utils import add_squares_to_images, filter_data_by_label, group_labels, add_background_to_images, add_one_square_to_images
from PIL import Image
from tqdm import tqdm
from glob import glob

import os
import torch

import numpy as np


class BiasedMNIST(MNIST):
    """MNIST variant with an injected bias, materialised as PNG files on disk.

    On first use the (filtered, relabelled, biased) images are written to
    ``<root>/BiasedMNIST/images/<bias_type>/<ratio>/<train|test>/<label>/<idx>.png``.
    Afterwards the dataset only keeps file paths and integer targets in
    memory and loads each image lazily in ``__getitem__``.

    Attributes set by this class:
        train: whether this is the training split.
        data_path: list of image file paths, one per sample.
        masked_data_path: list of masked image file paths (may be empty).
        targets: list of integer class labels aligned with ``data_path``.
        return_masked: when True, ``__getitem__`` also returns a masked image.
    """

    def __init__(
            self,
            class_labels_to_filter,
            old_to_new_label_mapping,
            data_ratio_to_inject_bias,
            square_number,
            bias_type,
            reverse=False,
            square_size=4,
            **kwargs
    ) -> None:
        """Build (if needed) and index the biased image folder.

        Args:
            class_labels_to_filter: forwarded to ``filter_data_by_label``.
            old_to_new_label_mapping: forwarded to ``group_labels``; one class
                directory is also created per key of this mapping.
            data_ratio_to_inject_bias: forwarded to the bias helper and used
                as a path component of the cache directory.
            square_number: forwarded to ``add_squares_to_images``.
            bias_type: ``'square'``, ``'background'`` or ``'one_square'``;
                also used as a path component of the cache directory.
            reverse: forwarded to the bias helper.
            square_size: forwarded to the square-based bias helpers.
            **kwargs: forwarded to ``MNIST.__init__``; ``root`` is required,
                ``train`` defaults to True when omitted.
        """
        super().__init__(**kwargs)
        # ``train`` is optional for MNIST, so do not require it here either.
        self.train = kwargs.get('train', True)
        self.return_masked = False
        dataset_to_create = 'train' if self.train else 'test'
        img_data_dir = os.path.join(
            kwargs['root'], 'BiasedMNIST', 'images', bias_type,
            str(data_ratio_to_inject_bias), dataset_to_create)
        # NOTE(review): any non-empty directory counts as a complete cache, so
        # an interrupted generation run is not detected here.
        cache_ready = os.path.isdir(img_data_dir) and len(
            os.listdir(img_data_dir)) > 0
        if not cache_ready:
            print(
                f"\n\nstart creating and saving {dataset_to_create} dataset of BiasedMnist\n\n")
            os.makedirs(img_data_dir, exist_ok=True)
            self.data, self.targets = filter_data_by_label(
                self.data, self.targets, class_labels_to_filter)
            self.targets = group_labels(
                self.targets, old_to_new_label_mapping)
            # Insert a channel axis and replicate it three times so the
            # images can be saved as 3-channel PNGs below.
            self.data = torch.unsqueeze(self.data, dim=1).repeat((1, 3, 1, 1))
            # NOTE(review): an unrecognised bias_type leaves the images
            # unbiased but still stores them under that bias_type directory.
            if bias_type == 'square':
                self.data = add_squares_to_images(
                    self.data, self.targets, data_ratio_to_inject_bias,
                    square_number, reverse, square_size)
            elif bias_type == 'background':
                self.data = add_background_to_images(
                    self.data, self.targets, data_ratio_to_inject_bias, reverse
                )
            elif bias_type == 'one_square':
                self.data = add_one_square_to_images(
                    self.data, self.targets, data_ratio_to_inject_bias,
                    reverse, square_size)
            for label in old_to_new_label_mapping:
                os.makedirs(os.path.join(
                    img_data_dir, str(label)), exist_ok=True)
            # Images are stored under the *remapped* target, which is not
            # guaranteed to be one of the mapping keys created above, so make
            # sure each destination directory exists before writing into it.
            ensured_dirs = set()
            for idx, (data, target) in enumerate(zip(self.data, self.targets)):
                class_dir = os.path.join(img_data_dir, str(target.item()))
                if class_dir not in ensured_dirs:
                    os.makedirs(class_dir, exist_ok=True)
                    ensured_dirs.add(class_dir)
                # CHW tensor -> HWC uint8 array, the layout PIL expects.
                Image.fromarray(
                    data.permute(1, 2, 0).numpy().astype(np.uint8)
                ).save(os.path.join(class_dir, f'{idx}.png'))
            # The tensors are no longer needed; samples are read from disk.
            self.data = []
            self.targets = []
            print(
                f"\n\nfinished creating and saving {dataset_to_create} dataset of BiasedMnist\n\n")
        self.update_data([img_data_dir])

    def update_data(self, data_file_directories, masked_data_file_path=None):
        """Re-index the dataset from one or more class-folder directories.

        Each directory is expected to contain sub-directories whose names are
        integer class labels; entries with non-integer names are skipped.

        Args:
            data_file_directories: iterable of directories to index.
            masked_data_file_path: optional directory with the same class
                sub-folders holding masked images. When given,
                ``__getitem__`` additionally returns the masked image.
        """
        self.data_path = []
        self.masked_data_path = []
        self.targets = []
        # Reset on every call: a stale True with an empty masked_data_path
        # would make __getitem__ fail with IndexError.
        self.return_masked = masked_data_file_path is not None
        for data_file_path in data_file_directories:
            data_classes = sorted(os.listdir(data_file_path))
            print("-"*10, f"indexing {'train' if self.train else 'test'} data", "-"*10)
            for data_class in tqdm(data_classes):
                try:
                    target = int(data_class)
                except ValueError:
                    # Not a class folder (name is not an integer label).
                    continue
                # glob returns files in arbitrary order; sort so that images
                # and masked images (sorted below) are paired index by index.
                class_image_file_paths = sorted(glob(
                    os.path.join(data_file_path, data_class, '*')))
                self.data_path += class_image_file_paths
                if masked_data_file_path is not None:
                    masked_class_image_file_paths = sorted(glob(
                        os.path.join(masked_data_file_path, data_class, '*')))
                    self.masked_data_path += masked_class_image_file_paths
                self.targets += [target] * len(class_image_file_paths)

    def update_data_with_path_list(self, data_file_pathes):
        """Re-index the dataset from an explicit list of image file paths.

        The class label of each file is the integer name of its parent
        directory. Masked images are not available after this call.

        Args:
            data_file_pathes: iterable of image file paths.
        """
        self.data_path = []
        self.masked_data_path = []
        self.targets = []
        # masked_data_path was just cleared, so masked retrieval must be off.
        self.return_masked = False
        print("-"*10, f"indexing {'train' if self.train else 'test'} data", "-"*10)
        for data_file_path in tqdm(data_file_pathes):
            # Parent directory name, independent of the OS path separator.
            target = int(os.path.basename(os.path.dirname(data_file_path)))
            self.data_path.append(data_file_path)
            self.targets.append(target)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.data_path)

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, img_file_path, target) where target is index of the
            target class. When ``return_masked`` is set the tuple is
            (image, img_file_path, target, masked_image).
        """
        img_file_path, target = self.data_path[index], self.targets[index]
        img = Image.open(img_file_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.return_masked:
            masked_img_file_path = self.masked_data_path[index]
            masked_img = Image.open(masked_img_file_path)
            # The masked image goes through the same transform as the image.
            if self.transform is not None:
                masked_img = self.transform(masked_img)
            return img, img_file_path, target, masked_img
        return img, img_file_path, target
