from torchvision import transforms
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
import os
from datasets import load_dataset

class CountBenchDataset(Dataset):
    """Map-style dataset over the ``nielsr/countbench`` Hugging Face dataset.

    Each item is an ``(image, count)`` pair taken from a record's ``'image'``
    and ``'number'`` fields. Records whose ``'image'`` field is ``None`` are
    dropped when the dataset is constructed.
    """

    # Split names accepted by the constructor.
    _VALID_SPLITS = ('train', 'val', 'test')

    def __init__(self, root: str, split: str, transform=None) -> None:
        """Load the records and keep those that have an image.

        Args:
            root: Stored as ``self.data_dir``; not otherwise used by this
                class.
            split: Must be ``'train'``, ``'val'`` or ``'test'``. It is only
                validated: the ``'train'`` split of the hub dataset is always
                the one loaded, whatever value is passed.
            transform: Optional callable applied to each image in
                ``__getitem__``. ``None`` returns the stored image as is.

        Raises:
            ValueError: If ``split`` is not one of the accepted names.
        """
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if split not in self._VALID_SPLITS:
            raise ValueError(
                f"split must be one of {self._VALID_SPLITS}, got {split!r}"
            )

        self.data_dir = root
        self.dataset = load_dataset('nielsr/countbench')['train']
        # Materialise the usable records once so indexing is a list lookup.
        self.data = [
            record for record in self.dataset if record['image'] is not None
        ]
        self.preprocess = transform

    def __len__(self) -> int:
        """Return the number of records that have an image."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, count)`` for the record at position ``idx``.

        The image is passed through the transform when one was given. If the
        result exposes a ``shape`` whose first entry is 1, it is repeated
        three times along that leading dimension. Objects without a ``shape``
        attribute (for example the untransformed image when no transform was
        supplied) are returned unchanged.
        """
        record = self.data[idx]
        img = record['image']
        if self.preprocess is not None:
            img = self.preprocess(img)

        # Without a transform the image may not be tensor-like, and reading
        # ``.shape`` on it would raise AttributeError; only tensor-like
        # results get the single-channel -> 3-channel expansion.
        shape = getattr(img, 'shape', None)
        if shape is not None and shape[0] == 1:
            img = img.repeat(3, 1, 1)

        return img, record['number']
