import torch
import torchvision

from utils.DataSpliter import DataSpliter


def init_dataset(args):
    """Build the train/test datasets named by ``args.dataset`` and split them.

    Supported names are ``'mnist'``, ``'cifar10'`` and ``'cifar100'``. Both
    splits are read from ``args.data_dir`` with a ``ToTensor`` transform and
    are requested with ``download=True`` so a fresh data directory works.

    Args:
        args: Run configuration. This function reads ``args.dataset`` and
            ``args.data_dir``; the whole object is forwarded unchanged to
            ``ssp_data_spliter``.

    Returns:
        Whatever ``ssp_data_spliter`` returns for the two datasets.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Map the CLI-facing dataset name to the torchvision class name. The class
    # itself is looked up only after validation so an unsupported name is
    # reported before any dataset work starts.
    dataset_class_names = {
        'mnist': 'MNIST',
        'cifar10': 'CIFAR10',
        'cifar100': 'CIFAR100',
    }
    try:
        class_name = dataset_class_names[args.dataset]
    except KeyError:
        supported = ', '.join(sorted(dataset_class_names))
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of: {supported}"
        ) from None

    dataset_cls = getattr(torchvision.datasets, class_name)
    root = f'{args.data_dir}'

    # download=True for BOTH splits: the training split is constructed first,
    # so it must be able to fetch the data itself when the directory is empty.
    train_data = dataset_cls(root=root, train=True, download=True,
                             transform=torchvision.transforms.ToTensor())
    test_data = dataset_cls(root=root, train=False, download=True,
                            transform=torchvision.transforms.ToTensor())
    return ssp_data_spliter(args, train_data, test_data)


def ssp_data_spliter(args, train_data, test_data):
    """Partition the training set across nodes and wrap the test set in a loader.

    Args:
        args: Run configuration. This function reads ``args.iid``,
            ``args.data_overlap``, ``args.dataset``, ``args.world_size`` and
            ``args.batch_size``; the whole object is also passed to
            ``DataSpliter``.
        train_data: Training dataset handed to ``DataSpliter``.
        test_data: Test dataset wrapped in a single ``DataLoader``.

    Returns:
        A ``(train_loaders, test_loader)`` tuple, where ``train_loaders`` is
        the result of ``DataSpliter.get_loaders()`` (presumably one loader per
        node -- confirm against ``DataSpliter``) and ``test_loader`` iterates
        ``test_data`` in order with batch size ``args.batch_size``.
    """
    # Only the split mode depends on the iid flag; exactly 0 selects the
    # non-iid split, any other value selects the iid split.
    split_mode = 'noniid' if args.iid == 0 else 'iid'
    train_spliter = DataSpliter(args, train_data, split_mode=split_mode,
                                overlap=args.data_overlap, data_type=args.dataset,
                                num_nodes=args.world_size, bsz=args.batch_size)
    train_loaders = train_spliter.get_loaders()

    # shuffle=False keeps evaluation order fixed across runs.
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    return train_loaders, test_loader

