import gym
import d4rl
import torch
import numpy as np
from omegaconf import OmegaConf

from dataset.generate import generate_preference_dataset
from models import load_policy


def load_dataset(
    config: OmegaConf,
    seg_len: int=30,
    device: torch.device='cpu',
    force_generate: bool=False,
    return_state: bool=False,
    ):
    """Load the cached preference dataset, generating it when needed.

    The raw data is read from ``config.data.path`` with ``np.load``. If
    ``force_generate`` is set, or the cached file cannot be loaded, a new
    dataset is produced by ``generate_preference_dataset`` using the
    environment ``config.env.name`` and the policy at ``config.policy.path``.

    Args:
        config: Configuration object; this function reads ``data.path``,
            ``data.num_pairs``, ``data.seg_len``, ``env.name`` and
            ``policy.path`` from it.
        seg_len: Number of leading steps of each segment exposed by the
            returned ``PreferenceDataset``. Generation itself uses
            ``config.data.seg_len``, not this argument.
        device: Device forwarded to ``load_policy`` and
            ``generate_preference_dataset``.
        force_generate: Skip the cache and always regenerate.
        return_state: Forwarded to ``PreferenceDataset``.

    Returns:
        A ``PreferenceDataset`` wrapping the loaded or generated data.
    """
    loaded = False
    # An explicit branch instead of `assert not force_generate`: asserts are
    # stripped under `python -O`, which would silently ignore the flag.
    if not force_generate:
        try:
            # NOTE(review): allow_pickle=True unpickles arbitrary objects;
            # only point config.data.path at trusted files.
            dataset = np.load(config.data.path, allow_pickle=True)
            print('Loaded dataset from:', config.data.path)
            loaded = True
        except Exception as err:
            # Best-effort cache: any load error falls back to regeneration.
            # `Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate instead of starting a generation run.
            print('Could not load cached dataset:', repr(err))

    if not loaded:
        print('Failed to load data from cache, or forced regeneration. Generating new dataset.')
        dataset = generate_preference_dataset(
            gym.make(config.env.name),
            load_policy(config, device=device, path=config.policy.path),
            data_path=config.data.path,
            num_pairs=config.data.num_pairs,
            seg_len=config.data.seg_len,
            device=device,
        )

    return PreferenceDataset(
        dataset,
        seg_len=seg_len,
        return_state=return_state,
    )


class PreferenceDataset(torch.utils.data.Dataset):
    """Torch dataset view over a sequence of preference records.

    Each record of ``dataset`` is unpacked as a 5-tuple
    ``(y1_states, y1_actions, y0_states, y0_actions, context)``. The state
    and action entries are indexed as 2-D arrays and cut to their first
    ``seg_len`` rows.
    (Presumably ``y1``/``y0`` are the two compared segments - confirm
    against the generator.)
    """

    def __init__(self, dataset, return_state=False, seg_len=30):
        self.return_state = return_state
        self.seg_len = seg_len
        self.dataset = dataset

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def _head(self, segment):
        """Return the first ``seg_len`` rows of ``segment`` as a tensor."""
        return torch.Tensor(segment[:self.seg_len, :])

    def __getitem__(self, idx):
        """Return one record as a tuple of tensors.

        With ``return_state`` false (the default) the result is
        ``(context, y1_actions, y0_actions)``, each flattened to 1-D.
        With ``return_state`` true the result is
        ``(context, y1_states, y1_actions, y0_states, y0_actions)``
        with no flattening applied.
        """
        states_1, actions_1, states_0, actions_0, ctx = self.dataset[idx]

        if not self.return_state:
            # Action-only view: everything collapsed to flat vectors.
            flat_ctx = torch.Tensor(ctx).flatten()
            flat_a1 = self._head(actions_1).flatten()
            flat_a0 = self._head(actions_0).flatten()
            return (flat_ctx, flat_a1, flat_a0)

        # Full view: keep the 2-D layout and include the state segments.
        ctx_tensor = torch.Tensor(ctx)
        s1 = self._head(states_1)
        a1 = self._head(actions_1)
        s0 = self._head(states_0)
        a0 = self._head(actions_0)
        return (ctx_tensor, s1, a1, s0, a0)

