import numpy as np
import torch
from torch.utils.data import Dataset


class PlayDataset(Dataset):
    """Random-length trajectory segments, each paired with the observation that follows it.

    ``data_container`` must expose ``epi_lengths`` (one length per episode) and
    support ``data_container[episode, index_or_slice]`` returning a dict that
    holds at least ``'observations'`` and ``'actions'``.

    Every ``(episode, start)`` pair with at least ``min_skill_length`` steps
    plus one goal step remaining in the episode is a dataset item. When
    ``percentage < 1.0`` a random subset of those pairs (drawn without
    replacement) is kept instead.
    """

    def __init__(self, data_container, min_skill_length, max_skill_length, use_padding, percentage=1.0):
        self.data_container = data_container
        self.min_skill_length = min_skill_length
        self.max_skill_length = max_skill_length
        self.use_padding = use_padding

        # Episodes shorter than ``min_skill_length`` yield an empty range and
        # therefore contribute no start positions.
        starts = [
            (episode, start)
            for episode, length in enumerate(self.data_container.epi_lengths)
            for start in range(length - min_skill_length)
        ]
        self.indices = np.array(starts)

        if percentage < 1.0:
            population = len(self.indices)
            keep = int(population * percentage)
            chosen = np.random.choice(np.arange(population), keep, replace=False)
            self.indices = self.indices[chosen]

    def __len__(self):
        """Return the number of (episode, start) pairs available."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Sample one segment starting at the ``idx``-th (episode, start) pair.

        The segment length is drawn uniformly between ``min_skill_length`` and
        ``max_skill_length``, capped so that the step right after the segment
        still lies inside the episode; that step's observation becomes
        ``'goals'``. Short segments are padded when ``use_padding`` is set.
        """
        episode, start = self.indices[idx]
        # Leave one step after the segment to serve as the goal.
        remaining = self.data_container.epi_lengths[episode] - start - 1
        longest = np.minimum(self.max_skill_length, remaining)
        length = np.random.randint(self.min_skill_length, longest + 1)
        stop = start + length

        segment = self.data_container[episode, start:stop]
        segment['goals'] = self.data_container[episode, stop]['observations']
        if self.use_padding and length < self.max_skill_length:
            segment = self.padding(segment)
        return segment

    def padding(self, trajectory):
        """Extend ``trajectory`` in place to ``max_skill_length`` steps and return it.

        Observations are padded with copies of the goal observation. Actions
        are padded with copies of the last action in which every component
        except the final one is zeroed (the original author's note says the
        final component is the gripper control). Other keys are left as is.
        """
        missing = self.max_skill_length - len(trajectory['observations'])
        goal_rows = np.repeat(trajectory['goals'][np.newaxis, ...], missing, axis=0)
        # np.repeat returns a new array, so zeroing it does not touch the
        # trajectory's own actions.
        hold_rows = np.repeat(trajectory['actions'][-1:], missing, axis=0)
        hold_rows[:, :-1] = 0.     # preserve only gripper control

        for key, tail in (('observations', goal_rows), ('actions', hold_rows)):
            trajectory[key] = np.concatenate([trajectory[key], tail], axis=0)
        return trajectory


class GCRLDataset(Dataset):
    """Single transitions paired with a goal observation from later in the same episode.

    ``data_container`` must expose ``epi_lengths`` and support
    ``data_container[episode, t]`` returning a dict with ``'observations'``
    (and ``'skills'`` when ``use_skill`` is set).

    Every ``(episode, t)`` with ``t < epi_length - min_dt - 1`` is an item;
    ``percentage < 1.0`` keeps a random subset drawn without replacement.

    NOTE(review): with ``use_skill`` the skill row is chosen from the drawn
    offset, not from the offset after clipping to the episode end - confirm
    this is intended.
    """

    def __init__(self, data_container, min_dt, max_dt, use_skill=False, skill_length=0, percentage=1.0):
        self.data_container = data_container
        self.min_dt, self.max_dt = min_dt, max_dt
        self.use_skill = use_skill
        self.skill_length = skill_length    # 0 for not using pretrained skill

        pairs = [
            (episode, step)
            for episode, length in enumerate(self.data_container.epi_lengths)
            for step in range(length - min_dt - 1)
        ]
        self.indices = np.array(pairs)

        if percentage < 1.0:
            population = len(self.indices)
            keep = int(population * percentage)
            chosen = np.random.choice(np.arange(population), keep, replace=False)
            self.indices = self.indices[chosen]

    def __len__(self):
        """Return the number of (episode, t) pairs available."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the transition at the ``idx``-th pair with ``'dt'`` and ``'goals'`` added.

        An offset is drawn uniformly from ``[min_dt, max_dt]``; the goal step
        is that many steps ahead, clipped to the last step of the episode.
        ``'dt'`` stores the offset actually used after clipping. With
        ``use_skill``, ``'actions'`` is replaced by one row of ``'skills'``.
        """
        episode, step = self.indices[idx]
        drawn_dt = np.random.randint(self.min_dt, self.max_dt + 1)
        sample = self.data_container[episode, step]

        last_step = self.data_container.epi_lengths[episode] - 1
        goal_step = np.clip(step + drawn_dt, 0, last_step)
        sample['dt'] = goal_step - step
        sample['goals'] = self.data_container[episode, goal_step]['observations']

        if not self.use_skill:
            return sample

        # Row 0 corresponds to an offset of ``min_dt``; offsets beyond
        # ``skill_length`` all map to the last usable row.
        skill_row = np.minimum(drawn_dt - self.min_dt, self.skill_length - self.min_dt)
        sample['actions'] = sample['skills'][skill_row, :]
        return sample


class MilestoneDataLoader:
    """Sample batches of evenly spaced observation sequences ("milestones").

    ``data_container`` must expose ``epi_lengths`` and support
    ``data_container[episode, timesteps]`` (``timesteps`` is an integer array)
    returning a dict with ``'observations'``.

    Only episodes with at least ``horizon`` steps are eligible for sampling.

    Args:
        data_container: episode storage as described above.
        max_interval: largest spacing (in steps) between consecutive milestones.
        horizon: number of milestones per sampled sequence.
        batch_size: default number of sequences returned by :meth:`sample`.
        percentage: stored on the instance; not used by the code in this class.
    """

    def __init__(self, data_container, max_interval, horizon, batch_size, percentage=1.0):
        self.data_container = data_container
        self.max_interval = max_interval
        self.horizon = horizon
        self.batch_size = batch_size
        self.epi_idxs_n_lengths = [
            (i, length)
            for i, length in enumerate(data_container.epi_lengths)
            if length >= horizon
        ]
        self.percentage = percentage

    def sample(self, batch_size=None):
        """Draw a batch of milestone sequences.

        For each sequence an eligible episode, a spacing ``interval`` and a
        start step are drawn at random such that all ``horizon`` steps
        ``t, t + interval, ...`` lie inside the episode.

        Args:
            batch_size: number of sequences; falls back to ``self.batch_size``
                when omitted (or falsy).

        Returns:
            dict with
            ``'observations'``: tensor stacking the per-sequence observations
            (one leading batch axis, then ``horizon``), and
            ``'intervals'``: float32 tensor of shape ``(batch, 1)`` holding
            ``(interval - u) / max_interval`` with ``u`` uniform in ``[0, 1)``.

        Raises:
            ValueError: if no episode has at least ``horizon`` steps.
        """
        batch_size = batch_size or self.batch_size
        if not self.epi_idxs_n_lengths:
            raise ValueError(
                'no episode has at least horizon={} steps to sample from'.format(self.horizon)
            )

        # Number of gaps between the first and the last milestone.
        span = self.horizon - 1
        observations, intervals = [], []
        for _ in range(batch_size):
            sample_idx = np.random.randint(len(self.epi_idxs_n_lengths))
            epi_idx, epi_length = self.epi_idxs_n_lengths[sample_idx]
            if span == 0:
                # A single milestone places no constraint on the spacing;
                # this also avoids dividing by zero below.
                max_interval = self.max_interval
            else:
                # Largest spacing for which the whole sequence still fits.
                max_interval = np.minimum((epi_length - 1) // span, self.max_interval)
            interval = np.random.randint(1, max_interval + 1)
            t = np.random.randint(epi_length - interval * span)
            timesteps = t + interval * np.arange(self.horizon)
            observations.append(self.data_container[epi_idx, timesteps]['observations'])
            # Jitter the integer spacing into (interval - 1, interval] before
            # normalising by the configured maximum.
            intervals.append((interval - np.random.rand()) / self.max_interval)

        return {
            'observations': torch.as_tensor(np.stack(observations)),
            'intervals': torch.as_tensor(np.stack(intervals), dtype=torch.float32).unsqueeze(1),
        }
