"""
This code is originally from https://github.com/NVlabs/SPADE
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import os
from data.dmc_dataset import RLDataset

class S2PDataset(RLDataset):
    """Dataset configuration for the S2P model (presumably state-to-pixel).

    All loading behaviour is inherited from ``RLDataset``. This subclass only
    adjusts command-line defaults: preprocessing mode, image sizes, and the
    number of semantic input channels, which is derived from ``--env_type``
    and ``--netG``. It also reports ``opt.dataroot`` as the data location.
    """

    # State dimensionality per environment, matched as a substring of
    # --env_type. Entries are checked in this order and the first match wins,
    # exactly like the original if/elif chain.
    _STATE_DIMS = (
        ('cheetah', 17),
        ('walker', 24),
        ('ballincup', 8),
        ('finger', 9),
        ('cartpole', 5),
        ('reacher', 6),
    )

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Set dataset-specific defaults on ``parser`` and return it.

        Args:
            parser: argparse parser that already defines the ``env_type`` and
                ``netG`` options read here.
            is_train: whether training options are being built. The defaults
                set here are currently identical for training and testing.

        Returns:
            The parser with updated defaults.

        Raises:
            ValueError: if ``netG`` selects a 'light' generator but
                ``env_type`` matches none of the known environments, so the
                semantic channel count cannot be derived.
        """
        parser = RLDataset.modify_commandline_options(parser, is_train)
        # Parse once and reuse the result; every parse_args() call re-parses
        # sys.argv. The set_defaults() calls below do not touch env_type or
        # netG, so reading them here gives the same values as reading later.
        opt = parser.parse_args()

        parser.set_defaults(
            preprocess_mode='resize_and_crop',
            load_size=128,
            crop_size=128,
            display_winsize=128,
        )

        state_num = next(
            (dim for name, dim in S2PDataset._STATE_DIMS if name in opt.env_type),
            None,
        )

        if 'light' in opt.netG:
            if state_num is None:
                known = ', '.join(name for name, _ in S2PDataset._STATE_DIMS)
                raise ValueError(
                    "Cannot derive semantic_nc for netG=%r: env_type=%r does "
                    "not contain any known environment (%s)."
                    % (opt.netG, opt.env_type, known)
                )
            # NOTE(review): 21 channels per state dimension plus 3 extra. The
            # meaning of these constants is not visible here; confirm against
            # the 'light' generator's input layout.
            parser.set_defaults(semantic_nc=21 * state_num + 3)
        else:
            # NOTE(review): 512 + 3 input channels for non-'light' generators;
            # confirm against the generator definition.
            parser.set_defaults(semantic_nc=512 + 3)
        parser.set_defaults(contain_dontcare_label=False)
        return parser

    def get_paths(self, opt):
        """Return ``opt.dataroot`` unchanged as this dataset's data location."""
        return opt.dataroot
