import torch
from torchvision import transforms
from .omniglot import omniglot
from .fixed_mnist import fixedMNIST

from torch.utils.data.dataset import Dataset

''' Adapted from https://github.com/yoonholee/pytorch-vae'''

def data_loaders(args):
    """Build the train and test DataLoaders for the dataset selected in ``args``.

    Args:
        args: Namespace-like object. Attributes read here:
            ``dataset`` -- 'omniglot' or 'fixedmnist';
            ``dataset_dir`` -- when non-empty, replaces the default root
            directory ('./dataset/<name>');
            ``cuda`` -- when truthy, adds ``num_workers=4`` and
            ``pin_memory=True`` to both DataLoaders;
            ``batch_size`` / ``test_batch_size`` -- batch sizes for the
            train and test loaders.

    Returns:
        tuple: ``(train_loader, test_loader)``. The train loader is built with
        ``shuffle=True``, the test loader with ``shuffle=False``. Both datasets
        are constructed with ``download=True`` and a ``ToTensor`` transform.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == 'omniglot':
        loader_fn, root = omniglot, './dataset/omniglot'
    elif args.dataset == 'fixedmnist':
        loader_fn, root = fixedMNIST, './dataset/fixedmnist'
    else:
        # Without this branch an unknown name fell through and crashed below
        # with an opaque UnboundLocalError on `loader_fn`.
        raise ValueError(
            f"Unknown dataset {args.dataset!r}; expected 'omniglot' or 'fixedmnist'")

    # A non-empty dataset_dir overrides the default root. A None value is
    # treated like '' instead of being passed through as the root path.
    if args.dataset_dir:
        root = args.dataset_dir

    # Extra DataLoader options are only used when running on CUDA.
    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

    train_loader = torch.utils.data.DataLoader(
        loader_fn(root, train=True, download=True, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # Need test bs <= 64 to make L_5000 tractable in one pass.
    test_loader = torch.utils.data.DataLoader(
        loader_fn(root, train=False, download=True, transform=transforms.ToTensor()),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader