import random
import torch
from torch_geometric.loader import NeighborLoader

class DataLoaders:
    """Build PyG ``NeighborLoader``s for a (train, valid, test) dataset triple.

    The three loaders are created eagerly in ``__init__`` and exposed as
    ``self.train``, ``self.valid`` and ``self.test``. A split whose dataset
    has length 0 gets ``None`` instead of a loader.

    Args:
        datasets: A 3-sequence ``(train, valid, test)``. Each item must expose
            ``y`` (its first dimension is used as the node count) and may
            expose ``mask``; when present, only nodes with ``mask == 1`` are
            used as seed nodes.
        batch_size: Number of seed nodes per mini-batch.
        workers: Forwarded to the loader as ``num_workers``.
        collate_fn: Stored on the instance but not forwarded to the loader.
        shuffle_train: Whether the training loader reshuffles seed nodes.
        shuffle_val: Whether the validation loader reshuffles seed nodes.
        max_train_nodes: Upper bound on the number of training seed nodes.
            Larger training sets are subsampled without replacement.
        num_neighbors: Forwarded to ``NeighborLoader``; ``None`` means
            ``[10]``.
    """

    def __init__(
            self,
            datasets,
            batch_size: int,
            workers: int = 0,
            collate_fn=None,
            shuffle_train=True,
            shuffle_val=False,
            max_train_nodes: int = 60000,
            num_neighbors=None,
    ):
        super().__init__()
        self.train_dataset, self.valid_dataset, self.test_dataset = datasets

        self.batch_size = batch_size

        self.workers = workers
        # Kept for interface compatibility; NeighborLoader builds its own
        # batches, so this value is not passed on.
        self.collate_fn = collate_fn
        self.shuffle_train, self.shuffle_val = shuffle_train, shuffle_val
        self.max_train_nodes = max_train_nodes
        # Build the default here rather than in the signature to avoid a
        # shared mutable default list.
        self.num_neighbors = [10] if num_neighbors is None else num_neighbors

        self.train = self.train_dataloader()
        self.valid = self.val_dataloader()
        self.test = self.test_dataloader()

    def train_dataloader(self):
        """Return the training loader, shuffled according to ``shuffle_train``."""
        # Honor the constructor flag instead of always shuffling.
        return self._make_dloader("train", shuffle=self.shuffle_train)

    def val_dataloader(self):
        """Return the validation loader, shuffled according to ``shuffle_val``."""
        return self._make_dloader("val", shuffle=self.shuffle_val)

    def test_dataloader(self):
        """Return the test loader (never shuffled)."""
        return self._make_dloader("test", shuffle=False)

    def _make_dloader(self, split, shuffle=False):
        """Build a ``NeighborLoader`` for ``split``.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            shuffle: Forwarded to ``NeighborLoader``.

        Returns:
            The loader, or ``None`` when the split's dataset has length 0.

        Raises:
            ValueError: If ``split`` is not a known split name.
        """
        datasets = {
            'train': self.train_dataset,
            'val': self.valid_dataset,
            'test': self.test_dataset,
        }
        if split not in datasets:
            raise ValueError(
                f"unknown split {split!r}; expected 'train', 'val' or 'test'"
            )
        dataset = datasets[split]
        if len(dataset) == 0:
            return None
        input_nodes = self._select_input_nodes(dataset, split)

        return NeighborLoader(
            dataset,
            shuffle=shuffle,
            input_nodes=input_nodes,
            batch_size=self.batch_size,
            num_neighbors=self.num_neighbors,
            num_workers=self.workers,
        )

    def _select_input_nodes(self, dataset, split):
        """Return the seed-node indices for ``dataset`` as an int64 tensor.

        Candidates are ``0 .. dataset.y.shape[0] - 1``, restricted to
        ``mask == 1`` when the dataset has a ``mask`` attribute. For the
        ``'train'`` split the result is also put in random order and, if
        longer than ``self.max_train_nodes``, subsampled without
        replacement.
        """
        indexes = torch.arange(dataset.y.shape[0])
        if hasattr(dataset, 'mask'):
            input_nodes = indexes[dataset.mask == 1]
        else:
            input_nodes = indexes
        if split == 'train':
            # random.sample draws distinct positions in random order, so a
            # single call both caps the count and shuffles. Do not call
            # random.shuffle on a tensor: its tuple swap goes through views
            # and duplicates elements instead of permuting them.
            count = min(len(input_nodes), self.max_train_nodes)
            picked = random.sample(range(len(input_nodes)), count)
            input_nodes = input_nodes[torch.tensor(picked, dtype=torch.long)]
            print('Num Nodes:', len(input_nodes))
        return input_nodes


