# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from torchvision import datasets



class CIFAR10SubSet(torch.utils.data.Dataset):  # TODO: use torch.utils.data.Dataset with batch sampling
    """CIFAR-10 split, optionally reduced to a random per-class sample."""

    def __init__(self, root, train=True, transform=None, download=True, returns="all", num_sample=None):
        """CIFAR-10 dataset that can be restricted to a random, per-class sample.
        Useful for VFL training

        Args:
            root: data root handed to ``torchvision.datasets.CIFAR10``
            train: whether to use the training or validation split (default: True)
            transform: image transform applied to each item in ``__getitem__``
            download: whether to download the data (default: True)
            returns: specify which data the client has; it is stored on the
                instance but not read anywhere else in this class
            num_sample: total number of samples to keep; ``int(num_sample / 10)``
                are drawn for each of the 10 classes. ``None`` (default) keeps
                the whole split.
        Returns:
            A PyTorch dataset
        """
        self.root = root
        self.train = train
        self.transform = transform
        self.download = download
        self.returns = returns
        self.data, self.target = self.__build_cifar_subset__(num_sample=num_sample)

    @staticmethod
    def _balanced_sample_indices(target, num_sample):
        """Pick ``int(num_sample / 10)`` random dataset indices for each of 10 classes.

        Args:
            target: 1-D sequence of integer class labels.
            num_sample: requested total number of samples.
        Returns:
            1-D array of indices into ``target``, grouped class by class.
        """
        num_classes = 10
        # Positions of the samples ordered by label: class ``c`` is expected to
        # occupy the c-th slice of length ``samples_per_class``. This assumes
        # the 10 classes are equally sized.
        order = np.asarray(target).argsort()
        samples_per_class = len(target) // num_classes
        draws_per_class = int(num_sample / num_classes)

        # np.random.choice samples with replacement by default, so the same
        # image can be selected more than once.
        # NOTE(review): confirm duplicates are acceptable for the subset.
        picks = [
            np.random.choice(samples_per_class, draws_per_class) + c * samples_per_class
            for c in range(num_classes)
        ]
        return order[np.concatenate(picks)]

    def __build_cifar_subset__(self, num_sample):
        """Load CIFAR-10 and return ``(data, target)``, subsampled if requested.

        Args:
            num_sample: total number of samples to keep, or ``None`` to keep
                the whole split.
        Returns:
            Tuple of the image array and the label array.
        """
        cifar_dataobj = datasets.CIFAR10(self.root, self.train, self.transform, download=self.download)
        data = cifar_dataobj.data
        target = np.array(cifar_dataobj.targets)

        if num_sample is not None:
            sample_idx = self._balanced_sample_indices(target, num_sample)
            data = data[sample_idx]
            target = target[sample_idx]

        return data, target

    def __getitem__(self, index):
        """Return ``(data, target)`` at ``index``, with ``transform`` applied to the data if set."""
        data, target = self.data[index], self.target[index]
        if self.transform is not None:
            data = self.transform(data)
        return data, target

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return len(self.data)
    


class CIFAR10withRepMask(torch.utils.data.Dataset):
    """Full CIFAR-10 split whose items also carry a boolean mask.

    The mask is built from the array stored at ``dev_root``: NaNs are replaced
    by ``inf``, then an entry is ``True`` where its absolute value is strictly
    greater than the ``args.compress`` quantile of its row (taken along dim 1).
    """

    def __init__(self, args, root, dev_root, train=True, transform=None, download=True):
        """Load the CIFAR-10 split and compute the mask.

        Args:
            args: object exposing ``compress``, used as the quantile ``q``.
            root: data root handed to ``torchvision.datasets.CIFAR10``.
            dev_root: path given to ``np.load``; the loaded array is indexed by
                sample index in ``__getitem__`` (presumably one row per
                sample - confirm against the code that writes it).
            train: whether to use the training or validation split (default: True).
            transform: image transform applied to each item in ``__getitem__``.
            download: whether to download the data (default: True).
        """
        self.root, self.dev_root = root, dev_root
        self.train, self.transform, self.download = train, transform, download

        cifar = datasets.CIFAR10(self.root, self.train, self.transform, download=self.download)
        self.data = cifar.data
        self.target = cifar.targets

        # Absolute values are needed for both the threshold and the comparison,
        # so compute them once.
        raw = np.load(dev_root)
        magnitude = torch.tensor(np.nan_to_num(raw, nan=np.inf)).abs()
        cutoff = torch.quantile(magnitude, q=args.compress, dim=1, keepdim=True)
        self.repMask = magnitude > cutoff

    def __getitem__(self, index):
        """Return ``(data, target, mask)`` at ``index``; ``transform`` is applied to the data if set."""
        image = self.data[index]
        label = self.target[index]
        rep_mask = self.repMask[index]
        if self.transform is None:
            return image, label, rep_mask
        return self.transform(image), label, rep_mask

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return len(self.data)


class CIFAR10SubsetWithRepMask(torch.utils.data.Dataset):
    """CIFAR-10 split with a boolean mask per item, optionally subsampled per class.

    The mask is built from the array stored at ``dev_root``: NaNs are replaced
    by ``inf``, then an entry is ``True`` where its absolute value is strictly
    greater than the ``args.compress`` quantile of its row (taken along dim 1).
    """

    def __init__(self, args, root, dev_root, train=True, transform=None, download=True, num_sample=None):
        """Compute the mask, load CIFAR-10 and optionally keep a random sample.

        Args:
            args: object exposing ``compress``, used as the quantile ``q``.
            root: data root handed to ``torchvision.datasets.CIFAR10``.
            dev_root: path given to ``np.load``; the loaded array is indexed
                with the same sample indices as the images (presumably one row
                per sample of the full split - confirm).
            train: whether to use the training or validation split (default: True).
            transform: image transform applied to each item in ``__getitem__``.
            download: whether to download the data (default: True).
            num_sample: total number of samples to keep; ``int(num_sample / 10)``
                are drawn for each of the 10 classes. ``None`` (default) keeps
                the whole split.
        """
        self.root = root
        self.dev_root = dev_root
        self.train = train
        self.transform = transform
        self.download = download

        # Absolute values feed both the threshold and the comparison.
        abs_dev = torch.tensor(np.nan_to_num(np.load(dev_root), nan=np.inf)).abs()
        thresh = torch.quantile(abs_dev, dim=1, keepdim=True, q=args.compress)
        repMask = abs_dev > thresh

        self.data, self.target, self.repMask = self.__build_cifar_subset__(num_sample=num_sample, masks=repMask)

    @staticmethod
    def _balanced_sample_indices(target, num_sample):
        """Pick ``int(num_sample / 10)`` random dataset indices for each of 10 classes.

        Args:
            target: 1-D sequence of integer class labels.
            num_sample: requested total number of samples.
        Returns:
            1-D array of indices into ``target``, grouped class by class.
        """
        num_classes = 10
        # Positions of the samples ordered by label: class ``c`` is expected to
        # occupy the c-th slice of length ``samples_per_class``. This assumes
        # the 10 classes are equally sized.
        order = np.asarray(target).argsort()
        samples_per_class = len(target) // num_classes
        draws_per_class = int(num_sample / num_classes)

        # np.random.choice samples with replacement by default, so the same
        # image can be selected more than once.
        # NOTE(review): confirm duplicates are acceptable for the subset.
        picks = [
            np.random.choice(samples_per_class, draws_per_class) + c * samples_per_class
            for c in range(num_classes)
        ]
        return order[np.concatenate(picks)]

    def __build_cifar_subset__(self, num_sample, masks):
        """Load CIFAR-10 and return ``(data, target, masks)``, subsampled if requested.

        Args:
            num_sample: total number of samples to keep, or ``None`` to keep
                the whole split.
            masks: mask tensor indexed with the same sample indices as the images.
        Returns:
            Tuple of the image array, the label array and the mask tensor.
        """
        cifar_dataobj = datasets.CIFAR10(self.root, self.train, self.transform, download=self.download)
        data = cifar_dataobj.data
        target = np.array(cifar_dataobj.targets)

        if num_sample is not None:
            sample_idx = self._balanced_sample_indices(target, num_sample)
            # The same indices select images, labels and masks so they stay aligned.
            data = data[sample_idx]
            target = target[sample_idx]
            masks = masks[sample_idx]

        return data, target, masks

    def __getitem__(self, index):
        """Return ``(data, target, mask)`` at ``index``; ``transform`` is applied to the data if set."""
        data, target, mask = self.data[index], self.target[index], self.repMask[index]
        if self.transform is not None:
            data = self.transform(data)
        return data, target, mask

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return len(self.data)