import os

import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random

from config import args


class SubDataset(torch.utils.data.Dataset):
    """Index-remapped view of another dataset.

    Position ``k`` of this view returns ``dataset[sample_idx[k]]`` and the
    view's length is ``len(sample_idx)``. Both the wrapped dataset and the
    index sequence are stored by reference; no samples are copied.
    """

    def __init__(self, dataset, sample_idx):
        self.dataset, self.sample_idx = dataset, sample_idx

    def __len__(self):
        return len(self.sample_idx)

    def __getitem__(self, index):
        # Translate the view position into a position of the wrapped dataset.
        wrapped_position = self.sample_idx[index]
        return self.dataset[wrapped_position]


def get_dataset(nodes, rank, batchsize, label_num, s_rate):
  """Build this rank's training loader and its shard of the test loader.

  The dataset is selected by ``args.dataset`` (imported from ``config``):
  ``'mnist'`` or ``'cifar10'``; any other value raises NotImplementedError.
  Both datasets are loaded from ``./data`` with a ``ToTensor`` transform and
  ``download=True``.

  The training set is re-ordered by label and then split across ranks
  according to ``data_distribution_method``. That variable is hard-coded
  below to ``'m_classes'``, so the ``'shuffle'`` branch only runs if that
  line is edited.

  * ``'m_classes'``: this rank receives one equal-sized chunk from each
    label marked in its row of ``generate_matrix(nodes, label_num)``.
  * ``'shuffle'``: the first ``s_rate`` fraction of every label is pooled and
    shuffled, the data is re-laid out, and this rank's indices are chosen by
    ``get_train_idx``.

  Args:
    nodes: Number of participants. It is the row count of the allocation
      matrix and ``num_replicas`` of the test DistributedSampler.
    rank: Index of this participant (passed to DistributedSampler as
      ``rank``).
    batchsize: Batch size of both returned loaders.
    label_num: Number of labels per rank in ``'m_classes'`` mode.
    s_rate: Fraction of each label that is shuffled in ``'shuffle'`` mode.

  Returns:
    ``(train_data_loader, test_data_loader)``. The training loader iterates
    this rank's subset with ``shuffle=True``. The test loader reads this
    rank's part of the test set through a ``DistributedSampler``. Both use
    ``num_workers=10``.

  Raises:
    NotImplementedError: If ``args.dataset`` is not ``'mnist'`` or
      ``'cifar10'``.
  """
  if args.dataset == 'mnist' or args.dataset == 'cifar10':
    # Load the raw train/test datasets.
    if args.dataset == 'mnist':
        # NOTE(review): this directory is created, but the data below is
        # downloaded to root='./data'; the directory appears unused here.
        os.makedirs("./dataset/mnist", exist_ok=True)
        mnist_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=mnist_transforms, download=True)
        test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=mnist_transforms, download=True)
    elif args.dataset == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10(
            root='./data',
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root='./data',
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        )

    # Re-order the training set by label so that each label's samples are
    # contiguous, from label 0 upward. torch.argsort is not stable by default,
    # so the order within a label is not guaranteed. Afterwards `targets` is a
    # torch tensor for both datasets; the `== label` comparisons below rely
    # on that.
    if args.dataset == 'mnist':
        sorted_indices = torch.argsort(train_dataset.targets)
        sorted_data = [train_dataset.data[i] for i in sorted_indices]
        sorted_targets = [train_dataset.targets[i] for i in sorted_indices]

        train_dataset.data = torch.stack(sorted_data)
        train_dataset.targets = torch.tensor(sorted_targets)

    elif args.dataset == 'cifar10':
        # CIFAR10 keeps `data` as a numpy array (hence np.stack) and `targets`
        # as a sequence that is wrapped in a tensor before sorting.
        sorted_indices = torch.argsort(torch.tensor(train_dataset.targets))
        sorted_data = [train_dataset.data[i] for i in sorted_indices]
        sorted_targets = [train_dataset.targets[i] for i in sorted_indices]

        train_dataset.data = np.stack(sorted_data)
        train_dataset.targets = torch.tensor(sorted_targets)

    # Partitioning strategy; set to 'shuffle' to use the s_rate-controlled
    # layout instead of the label-allocation matrix.
    data_distribution_method = 'm_classes'

    if data_distribution_method == 'm_classes':
        # i_start_pos[label] is the offset of the first sample of `label` in
        # the sorted training set (i_start_pos[10] is the total size), and
        # min_label_size is the smallest per-label count. Labels are assumed
        # to be 0..9.
        min_label_size = 0
        i_start_pos = [0]
        for label in range(10):
            label_indices = (train_dataset.targets == label).nonzero(as_tuple=True)[0]
            label_size = len(label_indices)
            i_start_pos.extend([i_start_pos[-1] + label_size])
            if label == 0:
                min_label_size = label_size
            elif label_size < min_label_size:
                min_label_size = label_size

        # For reference, MNIST training-set label counts:
        # mnist_label_size = [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
        # Each label is cut into ceil(nodes * label_num / 10) chunks. This is
        # the most ranks that generate_matrix lets share one label. Chunks are
        # sized from the rarest label, so all chunks are the same size;
        # samples past a label's last chunk are not assigned to any rank.
        label_chunk_num = (nodes * label_num - 1) // 10 + 1
        label_per_node_size = min_label_size // label_chunk_num

        # chunk_indices[label][i] holds the sorted-set positions of chunk i
        # of `label`.
        indices = list(range(len(train_dataset)))
        chunk_indices = []
        for label in range(10):
            label_chunk_indices = []
            for i in range(label_chunk_num):
                start_pos = i_start_pos[label] + i * label_per_node_size
                label_chunk_indices.append(indices[start_pos : start_pos + label_per_node_size])
            chunk_indices.append(label_chunk_indices)

        # NOTE(review): generate_matrix draws from the `random` module on every
        # call. Presumably every rank must seed `random` identically so that
        # all ranks build the same allocation_matrix; confirm with the caller.
        allocation_matrix = generate_matrix(nodes, label_num)
        if rank == 0:
            print("allocation_matrix : \n", allocation_matrix)

        # For each label allocated to this rank, take the chunk whose index
        # equals the number of lower ranks that also hold that label, so no
        # two ranks receive the same chunk.
        train_idx = []
        for label in range(10):
            if allocation_matrix[rank, label] == 1:
                train_idx.extend(chunk_indices[label][np.sum(allocation_matrix[:rank, label])])

    elif data_distribution_method == 'shuffle':

        # Shuffle the first s_rate fraction of each label across all labels
        # and leave the rest of each label in place; s_rate therefore
        # controls how heterogeneous the per-rank data is.
        first_half = []
        second_half = []

        # Split each label's indices into its first int(s_rate * len) samples
        # and the remaining samples.
        for label in range(10):
            label_indices = (train_dataset.targets == label).nonzero(as_tuple=True)[0]
            if rank == 0:
                print(f"Length of label {label} labels: \n{len(label_indices)}")
            half_len = int(s_rate*len(label_indices))
            first_half.extend(label_indices[:half_len].tolist())
            second_half.extend(label_indices[half_len:].tolist())

        # Shuffle the pooled leading portions (uses torch's global RNG).
        shuffled_half = torch.tensor(first_half)[torch.randperm(len(first_half))]

        # Re-lay the data as: for each label i, the i-th tenth of the shuffled
        # pool followed by label i's unshuffled samples. The last tenth
        # absorbs the remainder of the pool. half_len is recomputed here only
        # to locate each label's unshuffled tail.
        final_indices = []
        portion_size = len(shuffled_half) // 10
        for i in range(10):
            label_indices = (train_dataset.targets == i).nonzero(as_tuple=True)[0]
            half_len = int(s_rate*len(label_indices))

            if i == 9:
                last_portion_size = len(shuffled_half) - len(shuffled_half) // 10 * 9
                final_indices.extend(shuffled_half[i * portion_size: i * portion_size + last_portion_size].tolist())
            else:
                final_indices.extend(shuffled_half[i * portion_size: (i + 1) * portion_size].tolist())
            final_indices.extend(label_indices[half_len:].tolist())

        sorted_data = [train_dataset.data[i] for i in final_indices]
        sorted_targets = [train_dataset.targets[i] for i in final_indices]

        if args.dataset == 'mnist':
            train_dataset.data = torch.stack(sorted_data)
        elif args.dataset == 'cifar10':
            train_dataset.data = np.stack(sorted_data)

        train_dataset.targets = torch.tensor(sorted_targets)

        indices = list(range(len(train_dataset)))

        # get_train_idx treats every label block as TOTAL_SIZE // 10 samples
        # long. With the unequal MNIST label counts listed above, that is
        # only an approximation of the layout just built.
        if args.dataset == 'mnist':
            TOTAL_SIZE = 60000
        elif args.dataset == 'cifar10':
            TOTAL_SIZE = 50000

        train_idx = get_train_idx(indices, TOTAL_SIZE, s_rate, nodes, rank)

    # This rank's training subset, iterated in random order each epoch.
    train_dataset_split = SubDataset(train_dataset, train_idx)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset_split, batch_size=batchsize, shuffle=True, sampler=None, num_workers=10)

    # The test set is sharded across the `nodes` ranks by DistributedSampler.
    test_sampler_distribute = torch.utils.data.distributed.DistributedSampler(
        dataset=test_dataset,
        num_replicas=nodes,
        rank=rank)
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batchsize,  sampler=test_sampler_distribute, num_workers=10)

  elif args.dataset == 'synthetic':
    raise NotImplementedError
  else:
    raise NotImplementedError

  return train_data_loader, test_data_loader

def sample_latent(shape):
    """Draw a latent tensor of the given ``shape`` from a standard normal.

    Sampling uses torch's global RNG, so results follow ``torch.manual_seed``.
    """
    noise = torch.randn(shape)
    return noise

def get_train_idx(indices, TOTAL_SIZE, s_rate, nodes, rank):
    """Return the part of ``indices`` assigned to ``rank`` in 'shuffle' mode.

    ``indices`` is read as 10 consecutive label blocks of
    ``TOTAL_SIZE // 10`` entries each. Inside each block, the first
    ``int(block * s_rate)`` entries form the shuffled (mixed-label) region and
    the rest form the unshuffled region. This matches the layout built by
    ``get_dataset``.

    Rank ``rank`` of ``nodes`` takes about ``(TOTAL_SIZE // nodes) * s_rate``
    entries from the shuffled regions and about
    ``(TOTAL_SIZE // nodes) * (1 - s_rate)`` from the unshuffled regions. It
    starts in label block ``(rank * 10) // nodes`` and continues into the
    following blocks as needed.

    Args:
        indices: Sequence to slice, e.g. ``list(range(len(dataset)))``.
        TOTAL_SIZE: Nominal dataset size used to derive the block sizes.
        s_rate: Shuffled fraction in ``[0, 1]``; both endpoints are allowed.
        nodes: Number of ranks.
        rank: This rank's index.

    Returns:
        A list of the selected elements of ``indices``. Within each label
        block visited, the shuffled slice comes before the unshuffled one.

    NOTE: sizes are floored independently with ``int()``. When
    ``s_rate * size`` is not an integer, the shuffled/unshuffled split can be
    off by one relative to the data layout.
    """
    train_idx = []
    per_label_size = TOTAL_SIZE // 10
    per_node_size = TOTAL_SIZE // nodes
    label_shuffled_size = int(per_label_size * s_rate)
    label_unshuffled_size = int(per_label_size * (1 - s_rate))
    shuffled_size = int(per_node_size * s_rate)
    unshuffled_size = int(per_node_size * (1 - s_rate))
    start_turn = (rank * 10) // nodes

    # Offset of this rank's window inside the first label block it touches.
    # When s_rate is 0 (or 1), one region has size 0. Guard the modulo so it
    # does not divide by zero; an empty region simply starts at offset 0.
    if label_shuffled_size:
        shuffled_offset = (shuffled_size * rank) % label_shuffled_size
    else:
        shuffled_offset = 0
    if label_unshuffled_size:
        unshuffled_offset = (unshuffled_size * rank) % label_unshuffled_size
    else:
        unshuffled_offset = 0
    shuffled_start_pos = start_turn * per_label_size + shuffled_offset
    unshuffled_start_pos = start_turn * per_label_size + label_shuffled_size + unshuffled_offset

    # Number of label blocks overlapped by this rank's range
    # [per_node_size * rank, per_node_size * (rank + 1)). An end that falls
    # exactly on a block boundary does not open another block.
    turns = (per_node_size * (rank + 1)) // per_label_size - (per_node_size * rank) // per_label_size + 1
    if per_node_size * (rank + 1) % per_label_size == 0:
        turns -= 1

    if turns == 1:
        train_idx.extend(indices[shuffled_start_pos:shuffled_start_pos + shuffled_size])
        train_idx.extend(indices[unshuffled_start_pos:unshuffled_start_pos + unshuffled_size])
    else:
        remain_shuffled_size = shuffled_size
        remain_unshuffled_size = unshuffled_size
        for i in range(turns):
            if i < turns - 1:
                # Take everything up to the end of this block's two regions,
                # then move both cursors to the next label block.
                block_end = (start_turn + i + 1) * per_label_size
                shuffled_end = block_end - label_unshuffled_size
                train_idx.extend(indices[shuffled_start_pos:shuffled_end])
                train_idx.extend(indices[unshuffled_start_pos:block_end])
                remain_shuffled_size -= shuffled_end - shuffled_start_pos
                remain_unshuffled_size -= block_end - unshuffled_start_pos
                shuffled_start_pos = block_end
                unshuffled_start_pos = block_end + label_shuffled_size
            else:
                # Final block: take whatever amount is still owed.
                train_idx.extend(indices[shuffled_start_pos:shuffled_start_pos + remain_shuffled_size])
                train_idx.extend(indices[unshuffled_start_pos:unshuffled_start_pos + remain_unshuffled_size])

    return train_idx

def generate_matrix(n, label_num, num_classes=10):
    """Randomly assign ``label_num`` labels to each of ``n`` nodes.

    Returns an ``(n, num_classes)`` int matrix in which entry ``[i, c]`` is 1
    when node ``i`` holds label ``c``.

    Rows start as a round-robin assignment: node ``i`` gets labels
    ``i*label_num`` .. ``i*label_num + label_num - 1``, modulo
    ``num_classes``. As a result, each label is held by at most
    ``ceil(n * label_num / num_classes)`` nodes. The matrix is then scrambled
    with 1000 random 2x2 swaps that leave every row sum and column sum
    unchanged, so both properties survive. Randomness comes from the
    ``random`` module.

    Args:
        n: Number of nodes (rows).
        label_num: Labels per node; expected to be ``<= num_classes``.
        num_classes: Number of labels (columns). Defaults to 10.

    Returns:
        numpy int array of shape ``(n, num_classes)``.
    """
    matrix = np.zeros((n, num_classes), dtype=int)

    # Round-robin initial feasible solution.
    for i in range(n):
        matrix[i, (i * label_num + np.arange(label_num)) % num_classes] = 1

    # Swaps need two distinct rows and two distinct columns. With fewer
    # (including n == 0, where random.sample would raise ValueError), keep
    # the round-robin assignment.
    if n < 2 or num_classes < 2:
        return matrix

    for _ in range(1000):
        row1, row2 = random.sample(range(n), 2)
        col1, col2 = random.sample(range(num_classes), 2)
        # Only a "checkerboard" pattern (row1 = a b, row2 = b a) is swapped.
        # Exchanging the two columns in both rows then keeps row and column
        # sums intact.
        if matrix[row1, col1] == matrix[row2, col2] and matrix[row1, col2] == matrix[row2, col1]:
            matrix[row1, [col1, col2]] = matrix[row1, [col2, col1]]
            matrix[row2, [col1, col2]] = matrix[row2, [col2, col1]]

    return matrix
