from random import random
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image
import os
import random


class Img2ImgDataset(Dataset):
    """Paired image-to-image dataset (input image -> ground-truth image).

    Sample ``i`` pairs the i-th file of ``input_path`` with the i-th file of
    ``GT_path``, where both directories are listed in sorted filename order.
    Pairing is purely positional, so both directories must hold the same
    number of files whose sorted orders correspond.

    Args:
        input_path: Directory containing the input images.
        GT_path: Directory containing the ground-truth images.
        data_transform: Callable applied to each PIL image (called separately
            on the input and on the GT). Defaults to
            ``transforms.Compose([transforms.ToTensor()])``.
        rotation: If truthy, rotate both images of a pair by the same random
            angle from {0, 90, 180, 270} degrees before ``data_transform``.

    Raises:
        ValueError: If the two directories contain a different number of
            entries, because positional pairing would then be misaligned.
    """

    def __init__(self, input_path=None, GT_path=None, data_transform=None, rotation=False):
        self.inputs_dir = input_path
        self.GTs_dir = GT_path
        self.rotation = rotation

        if data_transform is None:
            self.data_transforms = transforms.Compose([transforms.ToTensor()])
        else:
            self.data_transforms = data_transform

        self.inputs = sorted(os.listdir(input_path))
        self.GTs = sorted(os.listdir(GT_path))

        # Pairs are matched by position only. With differing counts, every
        # pair after the first extra/missing file would be silently
        # misaligned (or GT indexing would go out of range), so fail early.
        if len(self.inputs) != len(self.GTs):
            raise ValueError(
                f"input directory {input_path!r} has {len(self.inputs)} entries "
                f"but GT directory {GT_path!r} has {len(self.GTs)}"
            )

    def __len__(self):
        """Return the number of (input, GT) pairs."""
        return len(self.inputs)

    def _pair_paths(self, idx):
        """Return the (input_path, gt_path) file paths for sample ``idx``."""
        # os.path.join works whether or not the directory ends with a separator.
        return (
            os.path.join(self.inputs_dir, self.inputs[idx]),
            os.path.join(self.GTs_dir, self.GTs[idx]),
        )

    @staticmethod
    def _load_image(path):
        """Open ``path`` and return a fully loaded image with its file closed."""
        with Image.open(path) as img:
            # Image.open is lazy; copy() forces the pixel data to be read so
            # the underlying file handle is released when the `with` exits.
            return img.copy()

    def __getitem__(self, idx):
        """Return the transformed (input, GT) pair at position ``idx``."""
        input_path, gt_path = self._pair_paths(idx)
        input_img = self._load_image(input_path)
        gt_img = self._load_image(gt_path)

        if self.rotation:
            # Same angle for both images so the pair stays aligned.
            # NOTE(review): Image.rotate defaults to expand=False, so for
            # non-square images a 90/270 rotation keeps the original canvas
            # size and crops the corners (filled with black).
            degree = random.choice([0, 90, 180, 270])
            input_img = input_img.rotate(degree)
            gt_img = gt_img.rotate(degree)

        # NOTE(review): the transform is invoked twice, independently. If
        # data_transform contains random ops (e.g. RandomCrop), input and GT
        # will receive different random parameters and may be misaligned.
        return self.data_transforms(input_img), self.data_transforms(gt_img)