import os
import pickle
from functools import partial
from typing import Callable

import d4rl  # noqa
import gym

from utilities.utils import compose

from .dataset import Dataset
from .preprocess import clip_actions, pad_trajs_to_dataset, split_to_trajs


def get_dataset(
    env,
    max_traj_length: int,
    termination_penalty: float = None,
    include_next_obs: bool = False,
    clip_to_eps: bool = False,  # disable action clip for debugging purpose
):
    """Build a ``D4RLDataset`` for ``env`` with the standard preprocessing chain.

    The raw data is transformed in three stages, in this order:

    1. ``clip_actions`` (configured with ``clip_to_eps``),
    2. ``split_to_trajs``,
    3. ``pad_trajs_to_dataset`` (configured with ``max_traj_length``,
       ``termination_penalty`` and ``include_next_obs``).

    Args:
        env: Passed unchanged to ``D4RLDataset`` as its ``env`` argument.
        max_traj_length: Forwarded to ``pad_trajs_to_dataset``.
        termination_penalty: Forwarded to ``pad_trajs_to_dataset``; may be None.
        include_next_obs: Forwarded to ``pad_trajs_to_dataset``.
        clip_to_eps: Forwarded to ``clip_actions``.

    Returns:
        The constructed ``D4RLDataset``.
    """
    # Configure each stage up front.
    clip_stage = partial(clip_actions, clip_to_eps=clip_to_eps)
    pad_stage = partial(
        pad_trajs_to_dataset,
        max_traj_length=max_traj_length,
        termination_penalty=termination_penalty,
        include_next_obs=include_next_obs,
    )

    # compose() is assumed to apply its arguments right-to-left (last one
    # first): clip actions, then split into trajectories, then pad.
    # TODO confirm against utilities.utils.compose.
    stages = (pad_stage, split_to_trajs, clip_stage)
    return D4RLDataset(env, preprocess_fn=compose(*stages))

class D4RLDataset(Dataset):
    """Dataset read from a pickled per-task file rather than ``env.get_dataset()``.

    The file ``<data_dir>/mt10_<env>.pkl`` is unpickled and kept verbatim on
    ``self.raw_dataset``. It is then passed through ``preprocess_fn``, and the
    resulting mapping is forwarded as keyword arguments to ``Dataset.__init__``.
    """

    # Directory the loader originally hard-coded. It is kept as the default so
    # existing callers see unchanged behavior; override it via ``data_dir``.
    DEFAULT_DATA_DIR = "/mnt/disk1/yxd/CORL/preferenceRL/pref_code/meta_pref/data"

    def __init__(
        self,
        env: str,
        preprocess_fn: Callable,
        *,
        data_dir: str = DEFAULT_DATA_DIR,
        **kwargs,
    ):
        """Load, preprocess and wrap the pickled dataset for ``env``.

        Args:
            env: Task name; selects the file ``mt10_{env}.pkl``.
            preprocess_fn: Callable that maps the unpickled object to a
                mapping of keyword arguments for ``Dataset.__init__``.
            data_dir: Directory containing the ``mt10_*.pkl`` files.
            **kwargs: Extra keyword arguments forwarded to ``Dataset.__init__``.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
        """
        dataset_path = os.path.join(data_dir, f"mt10_{env}.pkl")
        # SECURITY NOTE: pickle.load can execute arbitrary code while loading;
        # only point data_dir at trusted files.
        with open(dataset_path, "rb") as f:
            self.raw_dataset = dataset = pickle.load(f)
        data_dict = preprocess_fn(dataset)
        super().__init__(**data_dict, **kwargs)
