import os
import pickle
import torch
import numpy as np
from cv2 import imread
from torch.utils.data import Dataset
from torchvision.transforms.v2.functional import to_tensor

from tqdm import trange
from PIL import Image
import pickle as pkl
import torchvision.transforms as T

from util.utils import recursive_glob


class SUNRGBDDataset(Dataset):
    """Paired RGB / depth image dataset discovered by globbing a directory.

    RGB images are the ``jpg`` files found under ``root`` and depth maps are
    the ``png`` files found under ``root`` (via ``recursive_glob``, a project
    helper -- presumably a recursive search by file suffix; confirm against
    ``util.utils``).  Both lists are sorted and then paired purely by
    position, so sample ``i`` is ``(rgb_files[i], depth_files[i])``.

    Args:
        root: Directory that is searched for images.
        split: Accepted for interface compatibility; it is not used by this
            class, so every discovered file is included regardless of value.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the depth
            ``PIL.Image``.  Its result must support division by a float.

    Raises:
        RuntimeError: If no RGB image is found under ``root``.
    """

    def __init__(self, root, split='train', transform=None, target_transform=None):
        # Local import: only the pairing sanity check below needs it.
        import warnings

        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        # Paths containing 'fullres' are excluded from both modalities.
        self.rgb_files = sorted(recursive_glob(rootdir=os.path.join(root), suffix='jpg'))
        self.rgb_files = [x for x in self.rgb_files if 'fullres' not in x]

        # Depth additionally skips paths containing 'bfx' or 'depthRaw'.
        self.depth_files = sorted(recursive_glob(rootdir=os.path.join(root), suffix='png'))
        self.depth_files = [x for x in self.depth_files if 'fullres' not in x and 'bfx' not in x and 'depthRaw' not in x]

        if not self.rgb_files:
            # f-string instead of `+` so a path-like `root` does not raise
            # TypeError while the real error is being built.
            raise RuntimeError(f"Empty dataset - found no image pairs under \n{root}")

        # Samples are matched by index only, so unequal list lengths mean at
        # least some RGB/depth pairs are misaligned (or indexing will fail).
        if len(self.rgb_files) != len(self.depth_files):
            warnings.warn(
                'SUNRGBDDataset: found {:d} RGB files but {:d} depth files; '
                'pairs are matched by sorted position and may be misaligned'.format(
                    len(self.rgb_files), len(self.depth_files)),
                stacklevel=2,
            )

        print('found {:d} image pairs'.format(len(self.rgb_files)))

    def __len__(self):
        """Return the number of RGB images, which defines the dataset size."""
        return len(self.rgb_files)

    @staticmethod
    def _depth_to_tensor(depth):
        """Convert a raw depth image (or array-like) to a float32 tensor."""
        # np.array (not asarray) makes a writable copy, which from_numpy needs.
        return torch.from_numpy(np.array(depth, dtype=np.float32))

    def __getitem__(self, idx):
        """Load the ``idx``-th RGB / depth pair.

        Returns:
            A tuple ``(input, depth)``.  ``input`` is ``transform(rgb)``, or
            the unmodified RGB ``PIL.Image`` when no ``transform`` was given.
            ``depth`` is ``target_transform(depth)``, or a float32 tensor of
            the raw depth values when no ``target_transform`` was given; in
            both cases it is divided by 1e3 (presumably millimetres to
            metres -- confirm against the dataset's depth encoding).
        """
        rgb = Image.open(self.rgb_files[idx])
        depth = Image.open(self.depth_files[idx])

        # apply train/val transforms; fall back to usable defaults so the
        # outputs are always bound even when a transform is omitted.
        if self.transform is not None:
            input_tensor = self.transform(rgb)
        else:
            input_tensor = rgb

        if self.target_transform is not None:
            depth_tensor = self.target_transform(depth)
        else:
            # A PIL image cannot be divided by a float, so convert it first.
            depth_tensor = self._depth_to_tensor(depth)

        return input_tensor, depth_tensor / 1e3