import os

import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import numpy as np

from utils.fashion_mnist import MNIST, FashionMNIST

class SubDataset(torch.utils.data.Dataset):
    """Expose a selection of samples from another dataset.

    Item ``i`` of this dataset is item ``sample_idx[i]`` of the wrapped
    ``dataset``; the wrapped dataset itself is neither copied nor modified.
    """

    def __init__(self, dataset, sample_idx):
        """Store the wrapped ``dataset`` and the positions to expose from it."""
        self.dataset, self.sample_idx = dataset, sample_idx

    def __len__(self):
        """Return how many positions were selected."""
        return len(self.sample_idx)

    def __getitem__(self, index):
        """Return the wrapped dataset's item at the ``index``-th selected position."""
        source_position = self.sample_idx[index]
        return self.dataset[source_position]

def _load_datasets(args):
    """Build the (train, test) dataset pair selected by ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    dataroot = os.path.join(args.dataroot, args.dataset)
    name = args.dataset

    if name == 'mnist':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
        ])
        # NOTE(review): unlike the other branches this one ignores
        # args.dataroot / args.download and always downloads into './dataset'.
        # Kept as-is to avoid moving existing data; confirm whether intended.
        train_dataset = torchvision.datasets.MNIST(root='./dataset', train=True, transform=trans, download=True)
        test_dataset = torchvision.datasets.MNIST(root='./dataset', train=False, transform=trans, download=True)

    elif name == 'fashion-mnist':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, )),
        ])
        train_dataset = FashionMNIST(root=dataroot, train=True, download=args.download, transform=trans)
        test_dataset = FashionMNIST(root=dataroot, train=False, download=args.download, transform=trans)

    elif name == 'cifar':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = dset.CIFAR10(root=dataroot, train=True, download=args.download, transform=trans)
        test_dataset = dset.CIFAR10(root=dataroot, train=False, download=args.download, transform=trans)

    elif name == 'stl10':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
        ])
        train_dataset = dset.STL10(root=dataroot, split='train', download=args.download, transform=trans)
        test_dataset = dset.STL10(root=dataroot, split='test', download=args.download, transform=trans)

    else:
        # Previously an unknown name fell through and surfaced as an
        # UnboundLocalError on the first use of train_dataset.
        raise ValueError(
            f"Unsupported dataset {name!r}; expected one of "
            "'mnist', 'fashion-mnist', 'cifar', 'stl10'")

    return train_dataset, test_dataset


def _dataset_labels(dataset):
    """Return the class labels of ``dataset`` as a flat int64 tensor.

    Labels are read from ``dataset.targets`` when present, otherwise from
    ``dataset.labels`` (the attribute torchvision's STL10 uses).

    Raises:
        AttributeError: if the dataset exposes neither attribute.
    """
    for attr in ('targets', 'labels'):
        labels = getattr(dataset, attr, None)
        if labels is not None:
            # as_tensor accepts lists, ndarrays and tensors without the
            # copy-construct warning torch.tensor(tensor) would emit.
            return torch.as_tensor(labels).long().reshape(-1)
    raise AttributeError(
        f"{type(dataset).__name__} exposes neither 'targets' nor 'labels'")


def _partition_indices(labels, nodes, rank, chunks_per_node=2, num_classes=10):
    """Pick the training-sample indices that belong to ``rank``.

    The samples of every class are cut into equally sized single-class
    chunks, the chunks are shuffled with ``torch.randperm`` and each rank
    receives ``chunks_per_node`` consecutive chunks of the shuffled list.

    Args:
        labels: flat integer tensor with one class label per training sample.
        nodes: total number of participating nodes.
        rank: index of the node to build the shard for.
        chunks_per_node: number of single-class chunks handed to each node.
        num_classes: number of classes; labels are expected in
            ``[0, num_classes)``.

    Returns:
        A list of indices into the original (unsorted) dataset.

    Raises:
        ValueError: if the smallest class is too small to give every chunk
            at least one sample.
    """
    # Positions of the samples ordered by class label. Slicing this
    # permutation yields the same samples as physically sorting the dataset,
    # without touching dataset.data / dataset.targets.
    order = torch.argsort(labels)

    counts = torch.bincount(labels, minlength=num_classes)[:num_classes]
    # Offset of the first sample of each class inside `order`.
    class_starts = (torch.cumsum(counts, dim=0) - counts).tolist()
    min_label_size = int(counts.min())

    # Chunks per class = ceil(nodes * chunks_per_node / num_classes), so that
    # at least nodes * chunks_per_node chunks exist overall.
    chunks_per_label = (nodes * chunks_per_node - 1) // num_classes + 1
    # All chunks share one size, bounded by the smallest class.
    chunk_size = min_label_size // chunks_per_label
    if chunk_size == 0:
        raise ValueError(
            f"Cannot split the training set: smallest class has "
            f"{min_label_size} samples but {chunks_per_label} chunks per "
            f"class are required for {nodes} nodes")

    chunks = torch.stack([
        order[start + i * chunk_size:start + (i + 1) * chunk_size]
        for start in class_starts
        for i in range(chunks_per_label)
    ])

    # NOTE(review): shards are only disjoint across ranks if every rank draws
    # the same permutation here, i.e. torch is seeded identically on all
    # ranks before this call -- confirm in the caller.
    shuffled_chunks = chunks[torch.randperm(len(chunks))]

    first = rank * chunks_per_node
    return shuffled_chunks[first:first + chunks_per_node].reshape(-1).tolist()


def get_data_loader(args, rank):
    """Create the per-rank train and test data loaders.

    The training set is split by class into single-class chunks and each rank
    receives two randomly assigned chunks, giving every rank a label-skewed
    shard. The test set is divided across ranks with a ``DistributedSampler``.

    Args:
        args: namespace providing ``dataset`` (one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar'``, ``'stl10'``), ``dataroot``,
            ``download``, ``nodes`` and ``batch_size``.
        rank: index of the calling node, used to select its shard.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``; both use a per-node
        batch size of ``args.batch_size // args.nodes``.

    Raises:
        ValueError: if ``args.dataset`` is unsupported or the training set
            cannot be split into non-empty chunks for ``args.nodes`` nodes.
    """
    train_dataset, test_dataset = _load_datasets(args)

    # Check if everything is ok with loading datasets
    assert len(train_dataset) > 0, "training dataset is empty"
    assert len(test_dataset) > 0, "test dataset is empty"

    # The dataset object is deliberately left unmodified: replacing its
    # `data` with a NumPy stack breaks torchvision MNIST, whose __getitem__
    # calls `.numpy()` on each stored sample.
    train_idx = _partition_indices(_dataset_labels(train_dataset), args.nodes, rank)

    per_node_batch = args.batch_size // args.nodes

    train_dataset_split = SubDataset(train_dataset, train_idx)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_split, batch_size=per_node_batch, shuffle=True, sampler=None, num_workers=10)

    test_sampler_distribute = torch.utils.data.distributed.DistributedSampler(
        dataset=test_dataset,
        num_replicas=args.nodes,
        rank=rank)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=per_node_batch, sampler=test_sampler_distribute, num_workers=10)

    return train_dataloader, test_dataloader
