import json
import os
import glob
import random

from .base import BaseDataset_outpaint, BaseDataset_warp, WarpDataset, OutpaintDataset


class SingleCategory(BaseDataset_warp):
    """
    SingleCategory dataset

    Image paths are read from ``data_txt`` (one entry per line, relative to
    ``<root_path>/images``). Each image is paired with a depth file at
    ``<root_path>/depths/<stem>.npz``.

    Args:
        root_path (str): path to the dataset
        data_txt (str): path to a text file listing the images to use, one
            path per line, relative to ``<root_path>/images``. Blank lines
            are ignored.
        image_size (int): size of the images
        normalize (bool): whether to normalize the images to [-1, 1]
        normalize_depth (bool): whether to normalize the depth maps to [-1, 1]
        prepocess_depth (str): (parameter name spelled this way, sic) how to
            preprocess the depth maps (inputs from the dataset are disparity maps)
            - 'none': no preprocessing
            - 'to_depth': disparity map, to depth map
            - 'disparity_minmax': disparity map, min-max normalization, min=0, max=1
            - 'depth_minmax': depth map, min-max normalization, min=0, max=1
            - 'z_buffer': perspective projection to [0, 1]
        near (float): near plane for perspective projection
        far (float): far plane for perspective projection
        depth_cfg: forwarded unchanged to ``BaseDataset_warp``; its expected
            schema is defined there.
    """

    def __init__(self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        depth_cfg=None,
    ):
        super().__init__(root_path, data_txt, image_size, normalize, normalize_depth, prepocess_depth, near, far, depth_cfg)

    def get_fileinfo(self):
        """Populate ``self.images`` and ``self.depths`` from ``self.data_txt``.

        Each non-blank line of ``self.data_txt`` names one image, relative to
        ``<self.root_path>/images``. The matching depth file is
        ``<self.root_path>/depths/<stem>.npz``, where ``<stem>`` is the image's
        basename cut at its first ``.``. Both lists are index-aligned.

        Reads ``self.root_path`` and ``self.data_txt`` (presumably set by the
        base-class constructor — not visible here).

        Raises:
            AssertionError: if ``self.data_txt`` lists no images.
        """
        with open(self.data_txt, 'r') as file:
            images_list = [line.rstrip('\n') for line in file]
        # Skip blank / whitespace-only lines (e.g. a trailing empty line):
        # they would otherwise become bogus paths like '<root>/images/' and
        # '<root>/depths/.npz'.
        images_list = [name for name in images_list if name.strip()]

        self.images = [os.path.join(self.root_path, 'images', name) for name in images_list]
        if not self.images:
            # Explicit raise rather than ``assert`` so the check is not
            # stripped under ``python -O``; AssertionError is kept so existing
            # callers see the same exception type as before.
            raise AssertionError("Can't find data; make sure you specify the path to your dataset")

        # NOTE(review): the stem is cut at the FIRST '.', so 'a.b.png' maps to
        # 'a.npz' rather than 'a.b.npz'. Kept as-is because it must match how
        # the depth files were named when generated — confirm before changing.
        self.depths = [
            os.path.join(self.root_path, 'depths', name.split('/')[-1].split('.')[0] + '.npz')
            for name in images_list
        ]



class SingleCategoryWarp(WarpDataset, SingleCategory):
    """``WarpDataset`` combined with the ``SingleCategory`` file listing.

    All constructor arguments are forwarded positionally to
    ``super().__init__``, which resolves to ``WarpDataset`` (first base in
    the MRO). See ``SingleCategory`` for the shared dataset arguments. The
    warp-specific arguments (``augments``, ``std``, ``gen_inpainting_data``,
    ``viewset``, ``forward_warp``) are interpreted by ``WarpDataset``.

    ``augments`` defaults to ``None``, which is replaced by a fresh empty
    list for each instance.
    """

    def __init__(
        self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        augments=None,
        std=0.15,
        gen_inpainting_data=False,
        viewset='3x9',
        forward_warp='3daware',
        depth_cfg=None,
    ):
        # A ``[]`` default is evaluated once and shared by every instance,
        # so any in-place mutation (here or in a base class) would leak into
        # later datasets; build a new list per call instead.
        if augments is None:
            augments = []
        super().__init__(root_path, data_txt, image_size, normalize, normalize_depth, prepocess_depth, near, far, augments, std, gen_inpainting_data, viewset, forward_warp, depth_cfg)


class SingleCategoryOut(OutpaintDataset, SingleCategory):
    """``OutpaintDataset`` combined with the ``SingleCategory`` file listing.

    All constructor arguments are forwarded positionally to
    ``super().__init__``, which resolves to ``OutpaintDataset`` (first base
    in the MRO). See ``SingleCategory`` for the shared dataset arguments.
    The remaining arguments (``augments``, ``std``, ``gen_inpainting_data``,
    ``viewset``, ``forward_warp``) are interpreted by ``OutpaintDataset``.

    ``augments`` defaults to ``None``, which is replaced by a fresh empty
    list for each instance.
    """

    def __init__(
        self,
        root_path,
        data_txt,
        image_size,
        normalize=False,
        normalize_depth=False,
        prepocess_depth='none',
        near=0.5,
        far=100,
        augments=None,
        std=0.15,
        gen_inpainting_data=False,
        viewset='3x9',
        forward_warp='3daware',
        depth_cfg=None,
    ):
        # A ``[]`` default is evaluated once and shared by every instance,
        # so any in-place mutation (here or in a base class) would leak into
        # later datasets; build a new list per call instead.
        if augments is None:
            augments = []
        super().__init__(root_path, data_txt, image_size, normalize, normalize_depth, prepocess_depth, near, far, augments, std, gen_inpainting_data, viewset, forward_warp, depth_cfg)
