# Star import stays first so the explicit imports below win any name clash.
from .episode_container import *

import os

import numpy as np
import torch
from torchvision.transforms import ColorJitter
from torchvision.transforms.functional import affine


class Transform:
    """Random-shift + color-jitter augmentation followed by scaling to [-1, 1].

    Args:
        max_translate: Maximum absolute shift in pixels; each of the two
            translation components is drawn uniformly from the integers in
            ``[-max_translate, max_translate]``.
        brightness: Forwarded to ``torchvision.transforms.ColorJitter``.
        contrast: Forwarded to ``torchvision.transforms.ColorJitter``.
        hue: Forwarded to ``torchvision.transforms.ColorJitter``.
    """

    def __init__(self, max_translate=6, brightness=0.1, contrast=0.1, hue=0.02):
        self.color_jitter = ColorJitter(
            brightness=brightness,
            contrast=contrast,
            hue=hue,
        )
        self.max_translate = max_translate

    # `eval` shadows the builtin but is part of the public call signature.
    def __call__(self, image, eval=False):
        """Optionally augment ``image``, then rescale values from [0, 255] to [-1, 1].

        Args:
            image: Image to transform. NOTE(review): presumably a uint8 image
                tensor; confirm. torchvision's ColorJitter clamps float tensors
                to [0, 1], which would corrupt float inputs on the [0, 255]
                scale that the final rescaling assumes.
            eval: When true, skip the random shift and color jitter and only
                apply the rescaling.

        Returns:
            ``2 * image / 255 - 1``, computed on the (possibly augmented) image.
        """
        if not eval:
            # Draw the shift from torch's RNG, the same generator ColorJitter
            # uses. Seeding torch then controls all of this transform's
            # randomness. Torch is also reseeded per DataLoader worker, whereas
            # NumPy's global state is copied verbatim into forked workers.
            translate = torch.randint(
                -self.max_translate, self.max_translate + 1, (2,)
            ).tolist()
            image = affine(image, angle=0, translate=translate, scale=1, shear=[0])
            image = self.color_jitter(image)
        # Linear map: 0 -> -1, 255 -> 1.
        image = 2. * image / 255. - 1.
        return image


def load_dataset(env_name, dataset_cfg, env=None, num_train_shards=35, num_val_shards=6):
    """Build train/validation episode containers and store them in ``dataset_cfg``.

    The function returns ``None``. Its results are written in place to
    ``dataset_cfg['train_data']`` and ``dataset_cfg['val_data']``.

    Args:
        env_name: ``'calvin'`` selects the CALVIN pickle shards. Any other
            value wraps ``env`` in a ``D4rlEpisodeContainer``.
        dataset_cfg: Mapping read for ``'data_dir'`` and ``'use_skill'`` (CALVIN
            only) and written with the resulting containers.
        env: Environment object passed to ``D4rlEpisodeContainer`` in the
            non-CALVIN branch. Unused for CALVIN.
        num_train_shards: Number of ``train_<i>.pkl`` shards to load (CALVIN).
        num_val_shards: Number of ``validation_<i>.pkl`` shards to load (CALVIN).
    """
    if env_name == 'calvin':
        data_dir = dataset_cfg['data_dir']
        train_data_pths = [
            os.path.join(data_dir, 'train_%d.pkl' % i) for i in range(num_train_shards)
        ]
        val_data_pths = [
            os.path.join(data_dir, 'validation_%d.pkl' % i) for i in range(num_val_shards)
        ]
        # The two splits differ only in their shard paths. 'use_skill' swaps
        # the 'actions' modality for 'skills'.
        if dataset_cfg['use_skill']:
            modularities = ('observations', 'skills')
        else:
            modularities = ('observations', 'actions')
        dataset_cfg['train_data'] = CalvinEpisodeContainer(
            data_pths=train_data_pths,
            modularities=modularities,
        )
        dataset_cfg['val_data'] = CalvinEpisodeContainer(
            data_pths=val_data_pths,
            modularities=modularities,
        )
    else:
        dataset_cfg['train_data'] = D4rlEpisodeContainer(
            env, modularities=('observations', 'actions')
        )
        # No separate validation split here: the same container object is
        # reused for validation.
        dataset_cfg['val_data'] = dataset_cfg['train_data']

