import torchvision
from torch.utils.data import DataLoader
from loader.transforms.random_erasing import RandomErasing
from torchvision import transforms

from PIL import Image
from config import cfg

from loader.cifar10 import get_cifar10
from loader.cifar100 import get_cifar100
from loader.fashion_mnist import get_fashion_mnist
from loader.imagenet import get_imagenet, get_imagenet_mini, get_imagenet_subset
from loader.cub200 import get_cub200
from loader.big_cub200 import get_big_cub200

def get_loader(data_type=None):
    """Dispatch to the dataset-specific loader factory selected by config.

    Args:
        data_type: Optional dataset key that overrides ``cfg.data.type``.
            When ``None`` (the default), ``cfg.data.type`` is used, which
            matches the original zero-argument behaviour.

    Returns:
        Whatever the selected ``get_*`` factory returns. These are presumably
        the data loaders for that dataset; confirm against the individual
        ``loader.*`` modules.

    Raises:
        KeyError: If the dataset key is not one of the registered names. The
            message lists the supported keys so config typos are easy to spot.
    """
    # Registry of dataset key -> zero-argument factory.
    pair = {
        'fashion.mnist': get_fashion_mnist,
        'cifar10': get_cifar10,
        'cifar100': get_cifar100,
        'imagenet': get_imagenet,
        'imagenet.subset': get_imagenet_subset,
        'imagenet.mini': get_imagenet_mini,
        'cub200': get_cub200,
        'cub200.big': get_big_cub200,
    }

    key = cfg.data.type if data_type is None else data_type

    # Only the lookup is guarded, so a KeyError raised inside a factory
    # still propagates untouched, exactly as before.
    try:
        factory = pair[key]
    except KeyError:
        supported = ", ".join(sorted(pair))
        # Keep KeyError for backward compatibility, but make the message useful.
        raise KeyError(
            f"Unknown dataset type {key!r}; expected one of: {supported}"
        ) from None

    return factory()
