import os

import torch
import torch.utils.data as data
from torchvision import transforms, datasets

import utils

def get_normal_data(args):
    """Build MNIST train/test/generalization loaders with no training augmentation.

    Every split is resized to 32 pixels and converted with ``ToTensor``.
    The generalization split is the MNIST test set passed through a
    ``RandomAffine`` whose rotation and translation ranges are twice
    ``args.rotate`` and ``args.translate``.

    Args:
        args: Namespace providing ``dataset``, ``data`` (dataset root
            directory), ``batch_size``, ``rotate`` and ``translate``.
            ``num_classes`` and ``num_channels`` are written onto it.

    Returns:
        Tuple ``(train_loader, test_loader, generalization_loader, args)``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'mnist'``.
    """
    # Only MNIST is wired up; fail loudly instead of reaching the loader
    # construction below with the datasets never having been created.
    if args.dataset != 'mnist':
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    # Side length after Resize (MNIST is square, so images become 32x32).
    image_size = 32

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    # Generalization ranges are double the configured ones. RandomAffine's
    # ``translate`` is a fraction of the image side, so the doubled shift is
    # divided by the image size (assumes args.translate is in pixels).
    rotate = int(args.rotate * 2)
    translate = int(args.translate * 2) / image_size
    generalization_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomAffine(
            degrees=rotate,
            translate=(translate, translate),
        ),
        transforms.ToTensor(),
    ])

    train_data = datasets.MNIST(
        root=args.data,
        download=True,
        train=True,
        transform=transform,
    )
    test_data = datasets.MNIST(
        root=args.data,
        download=True,
        train=False,
        transform=transform,
    )
    # Same underlying test images as ``test_data``, but randomly transformed.
    generalization_data = datasets.MNIST(
        root=args.data,
        download=True,
        train=False,
        transform=generalization_transform,
    )
    args.num_classes = 10
    args.num_channels = 1

    train_loader = data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=4,
        shuffle=True,
        drop_last=True,
    )
    # Evaluation loaders: fixed batch size, single process, full coverage.
    test_loader = data.DataLoader(
        test_data,
        batch_size=128,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
        drop_last=False,
    )
    generalization_loader = data.DataLoader(
        generalization_data,
        batch_size=128,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
        drop_last=False,
    )

    return train_loader, test_loader, generalization_loader, args

def get_augmented_data(args):
    """Build MNIST loaders whose training split uses ``utils.RandomTransformations``.

    The training set is wrapped with ``utils.RandomTransformations``, configured
    from ``args.num_samples``, ``args.rotate`` and ``args.translate``. The test
    split is resized to 32 pixels and converted with ``ToTensor``. The
    generalization split is the MNIST test set passed through a
    ``RandomAffine`` whose ranges are twice ``args.rotate`` and
    ``args.translate``.

    Args:
        args: Namespace providing ``dataset``, ``data`` (dataset root
            directory), ``batch_size``, ``num_samples``, ``rotate`` and
            ``translate``. ``num_classes``, ``num_channels`` and ``size``
            are written onto it.

    Returns:
        Tuple ``(train_loader, test_loader, generalization_loader, args)``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'mnist'``.
    """
    # Only MNIST is wired up; fail loudly instead of reaching the loader
    # construction below with the datasets never having been created.
    if args.dataset != 'mnist':
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    args.num_classes = 10
    args.num_channels = 1
    # NOTE(review): the training transform receives size=28, while the
    # evaluation splits below are resized to 32 -- confirm that
    # utils.RandomTransformations produces inputs of a matching size.
    args.size = 28

    # Side length after Resize for the evaluation splits (32x32 for MNIST).
    image_size = 32

    # No Normalize step: the training transform is built with
    # normalize=False, so evaluation inputs are left as ToTensor output too.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    # Generalization ranges are double the configured ones. RandomAffine's
    # ``translate`` is a fraction of the image side, so the doubled shift is
    # divided by the image size (assumes args.translate is in pixels).
    rotate = int(args.rotate * 2)
    translate = int(args.translate * 2) / image_size
    generalization_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomAffine(
            degrees=rotate,
            translate=(translate, translate),
        ),
        transforms.ToTensor(),
    ])
    train_transform = utils.RandomTransformations(
        num_samples=args.num_samples,
        rotate=args.rotate,
        translate=args.translate,
        size=args.size,
        normalize=False,
    )

    train_data = datasets.MNIST(
        root=args.data,
        transform=train_transform,
        download=True,
        train=True,
    )
    test_data = datasets.MNIST(
        root=args.data,
        download=True,
        train=False,
        transform=transform,
    )
    # Same underlying test images as ``test_data``, but randomly transformed.
    generalization_data = datasets.MNIST(
        root=args.data,
        download=True,
        train=False,
        transform=generalization_transform,
    )

    train_loader = data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=4,
        shuffle=True,
        drop_last=True,
    )
    # Evaluation loaders: fixed batch size, single process, full coverage.
    test_loader = data.DataLoader(
        test_data,
        batch_size=128,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
        drop_last=False,
    )
    generalization_loader = data.DataLoader(
        generalization_data,
        batch_size=128,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
        drop_last=False,
    )
    return train_loader, test_loader, generalization_loader, args
