from numbers import Integral

from torch.utils.data import Dataset
from torch import randn, Tensor, device


class Input_Dataset(Dataset):
    """Dataset of Gaussian input vectors, either fixed or resampled on every access.

    Each sample is ``input_std * randn(input_dim)``. Items are returned as
    ``(index, sample)`` pairs.

    Offline mode (``online`` falsy): a ``(datanum, input_dim)`` tensor is drawn
    once at construction, stored in ``self.data``, and indexed on every access.

    Online mode (``online`` truthy): nothing is stored. Every ``__getitem__``
    call draws fresh samples, shaped like what indexing the offline tensor
    with the same index would return.

    Args:
        input_std: Scale multiplied onto standard-normal draws.
        input_dim: Length of each sample vector.
        datanum: Number of samples. It is reported by ``len()`` and bounds the
            valid integer indices to ``[-datanum, datanum)``.
        online: If truthy, resample on each access instead of storing data.
        device: Device on which samples are created (CPU by default).
    """

    def __init__(self, input_std=1.0, input_dim=2, datanum=4096, online=True, device=device('cpu')):
        # Zero-argument super() binds to this instance so Dataset.__init__ runs on it.
        super().__init__()

        self.std = input_std
        self.dim = input_dim
        self.datanum = datanum
        self.online = online
        self.device = device

        self.generate_data()

    def __len__(self):
        """Return the number of samples, ``datanum``."""
        return self.datanum

    def generate_data(self):
        """Draw and store the fixed sample tensor; does nothing in online mode."""
        # Use the same truthiness test as __getitem__, so a truthy non-bool
        # `online` (e.g. 1) does not pre-generate data that is never read.
        if not self.online:
            self.data = self.std * randn(self.datanum, self.dim, device=self.device)

    def __getitem__(self, index):
        """Return ``(index, sample)`` for *index*.

        Offline, the sample is ``self.data[index]``. Online, fresh samples are
        drawn with the shape that offline indexing would produce:

        * integer (incl. NumPy integers): one vector of length ``dim``;
        * ``Tensor``: ``index.shape + (dim,)``. A boolean mask is treated by
          its shape, not by its count of True entries;
        * anything else (slice, list, ...): a full ``(datanum, dim)`` draw
          indexed by *index*.

        Raises:
            IndexError: if an integer index is outside ``[-datanum, datanum)``.
                Raising this also makes plain iteration (``for item in ds``)
                stop after ``datanum`` items, because ``Dataset`` has no
                ``__iter__`` and Python falls back to the sequence protocol.
        """
        if not self.online:
            return index, self.data[index]

        if isinstance(index, Tensor):
            # randn(*(), dim) == randn(dim), so 0-dim tensor indices work too.
            return index, self.std * randn(*index.shape, self.dim, device=self.device)

        if isinstance(index, Integral):
            if not -self.datanum <= index < self.datanum:
                raise IndexError(
                    f"index {index} is out of range for dataset of length {self.datanum}"
                )
            return index, self.std * randn(self.dim, device=self.device)

        # Slices, lists, etc.: draw a full set and let tensor indexing select rows.
        return index, self.std * randn(self.datanum, self.dim, device=self.device)[index]
