from collections import OrderedDict

import numpy as np
import torch

from rlkit.data_management.replay_buffer import ReplayBuffer
import rlkit.torch.pytorch_util as ptu


class SimpleReplayBuffer(ReplayBuffer):
    """Fixed-capacity ring buffer of transitions backed by NumPy arrays.

    Every slot stores an observation, action, reward, terminal flag and next
    observation, plus one row per ``env_info`` key. New samples are written
    at ``self._top``, which wraps back to 0 after ``max_replay_buffer_size``
    writes, so the oldest slots are overwritten once the buffer is full.

    Besides the transition arrays, the buffer carries optional per-slot
    arrays (``_gamma``, ``_q_curr``, ``_soft_weight``, ``_cls_len``,
    ``_time_step_val``). Whichever of them is not ``None`` is added to the
    batches returned by :meth:`random_batch`.
    """

    def __init__(
        self,
        max_replay_buffer_size,
        observation_dim,
        action_dim,
        env_info_sizes,
    ):
        """Allocate zero-filled storage for ``max_replay_buffer_size`` slots.

        Args:
            max_replay_buffer_size: Number of slots in the ring buffer.
            observation_dim: Row width of the observation arrays.
            action_dim: Row width of the action array.
            env_info_sizes: Mapping from env-info key to the row width of
                the array allocated for it. ``None`` is treated as an empty
                mapping.
        """
        if env_info_sizes is None:
            env_info_sizes = {}

        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._max_replay_buffer_size = max_replay_buffer_size
        self._observations = np.zeros((max_replay_buffer_size, observation_dim))
        # It's a bit memory inefficient to save the observations twice,
        # but it makes the code *much* easier since you no longer have to
        # worry about termination conditions.
        self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
        self._actions = np.zeros((max_replay_buffer_size, action_dim))
        # Make everything a 2D np array to make it easier for other code to
        # reason about the shape of the data
        self._rewards = np.zeros((max_replay_buffer_size, 1))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')

        # Optional per-slot arrays. ``_gamma`` is overwritten by
        # calculate_gamma_return() and ``_q_curr`` by init_qf(); nothing in
        # this class assigns ``_soft_weight`` or ``_time_step_val``, so they
        # stay None unless set from outside.
        self._gamma = np.zeros((max_replay_buffer_size, 1))
        self._q_curr = None
        self._soft_weight = None
        self._cls_len = np.ones((max_replay_buffer_size, 1))
        self._time_step_val = None

        # Define self._env_infos[key][i] to be the return value of env_info[key]
        # at time i
        self._env_infos = {
            key: np.zeros((max_replay_buffer_size, size))
            for key, size in env_info_sizes.items()
        }
        # Snapshot the keys as a list: a dict_keys view cannot be pickled and
        # would silently follow later mutations of the caller's dict.
        self._env_info_keys = list(env_info_sizes)

        self._top = 0
        self._size = 0

    def _write_transition(self, observation, action, reward,
                          next_observation, terminal):
        """Write one transition into slot ``self._top`` without advancing."""
        slot = self._top
        self._observations[slot] = observation
        self._actions[slot] = action
        self._rewards[slot] = reward
        self._terminals[slot] = terminal
        self._next_obs[slot] = next_observation

    def add_sample(self, observation, action, reward, next_observation,
                   terminal, env_info, **kwargs):
        """Store one transition and its env-info values, then advance.

        Args:
            observation: Value written to the observation row.
            action: Value written to the action row.
            reward: Value written to the reward row.
            next_observation: Value written to the next-observation row.
            terminal: Value written to the terminal row (stored as uint8).
            env_info: Mapping that must contain every key given in
                ``env_info_sizes`` at construction time.
            **kwargs: Accepted and ignored.
        """
        self._write_transition(
            observation, action, reward, next_observation, terminal,
        )
        for key in self._env_info_keys:
            self._env_infos[key][self._top] = env_info[key]
        self._advance()

    def add_sample_only(self, observation, action, reward, next_observation, terminal):
        """Store one transition without touching the env-info arrays."""
        self._write_transition(
            observation, action, reward, next_observation, terminal,
        )
        self._advance()

    def terminate_episode(self):
        """Do nothing; this buffer keeps no per-episode state."""
        pass

    def _advance(self):
        """Move the write cursor forward (wrapping) and grow ``_size`` up to capacity."""
        self._top = (self._top + 1) % self._max_replay_buffer_size
        if self._size < self._max_replay_buffer_size:
            self._size += 1

    def random_batch(self, batch_size, prob=None, get_idx=False):
        """Sample ``batch_size`` stored transitions.

        Args:
            batch_size: Number of indices to draw from ``range(self._size)``.
            prob: Optional 1-D array-like of sampling probabilities, passed
                to ``np.random.choice`` as ``p``. When omitted, indices are
                drawn uniformly *without* replacement (``np.random.choice``
                raises ``ValueError`` if ``batch_size > self._size``). When
                given, indices are drawn *with* replacement.
            get_idx: If truthy, also return the sampled indices.

        Returns:
            A dict of arrays indexed by the sampled indices, or the tuple
            ``(batch, indices)`` when ``get_idx`` is truthy.
        """
        if prob is None:
            indices = np.random.choice(self._size, batch_size, replace=False)
        else:
            indices = np.random.choice(self._size, batch_size, replace=True, p=prob)

        batch = dict(
            observations=self._observations[indices],
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            next_observations=self._next_obs[indices],
        )

        # Optional per-slot arrays are included only when they have been set.
        optional_fields = (
            ('gamma', self._gamma),
            ('q_curr', self._q_curr),
            ('soft_weight', self._soft_weight),
            ('cls_len', self._cls_len),
            ('time_step_val', self._time_step_val),
        )
        for name, values in optional_fields:
            if values is not None:
                batch[name] = values[indices]

        for key in self._env_info_keys:
            # An env-info key must not shadow one of the fields above.
            assert key not in batch
            batch[key] = self._env_infos[key][indices]

        if get_idx:
            return batch, indices
        return batch

    def validate_sample(
            self, model, batch_size=1024,
    ) -> torch.Tensor:
        """Compute squared-error losses of ``model`` on a random batch.

        Draws ``batch_size`` indices uniformly without replacement, calls
        ``model(observations, actions)`` on CPU float tensors, and compares
        the output against ``[next_observation, reward]`` concatenated along
        the last axis.

        Args:
            model: Callable taking ``(observations, actions)`` tensors.
                NOTE(review): the reduction over dims ``(1, 2)`` requires a
                3-D output, presumably ``(ensemble, batch, obs_dim + 1)`` -
                confirm against the model used by callers.
            batch_size: Number of transitions to evaluate.

        Returns:
            The squared error averaged over dims 1 and 2 of the model output.
        """
        ind = np.random.choice(self._size, size=batch_size, replace=False)

        s = torch.FloatTensor(self._observations[ind])
        a = torch.FloatTensor(self._actions[ind])
        ns = torch.FloatTensor(self._next_obs[ind])
        r = torch.FloatTensor(self._rewards[ind])

        hat = model(s, a)
        target = torch.cat([ns, r], dim=-1)
        val_losses = ((hat - target) ** 2).mean(dim=(1, 2))

        return val_losses

    def calculate_gamma_return(self, gamma=0.99):
        """Fill ``self._gamma`` with discounted returns and return it.

        Walks the stored slots from index ``self._size - 1`` down to 0 and
        sets ``G[i] = r[i] + gamma * G[i + 1] * (1 - terminal[i])``, with the
        return following the last stored slot taken as 0.

        NOTE(review): slots are treated as consecutive time steps in index
        order; once the ring buffer has wrapped, index order is no longer
        insertion order - confirm this is only called before wraparound.

        Args:
            gamma: Discount factor.

        Returns:
            The array of discounted returns, same shape as ``self._rewards``.
        """
        gam_return = np.zeros_like(self._rewards)
        pre_return = 0
        for i in reversed(range(self._size)):
            # A terminal flag at step i cuts off the return from later steps.
            gam_return[i] = self._rewards[i] + gamma * pre_return * (1 - self._terminals[i])
            pre_return = gam_return[i]

        self._gamma = gam_return

        return self._gamma

    def init_qf(self, qf, batch_size=10000):
        """Evaluate ``qf`` on every stored (observation, action) pair.

        The result is written to ``self._q_curr`` (same shape as
        ``self._rewards``), so ``qf`` must return values that can be
        assigned into a ``(chunk, 1)`` slice.

        Args:
            qf: Callable taking ``(observations, actions)`` tensors produced
                by ``ptu.from_numpy``.
            batch_size: Number of transitions evaluated per forward pass.
        """
        self._q_curr = np.zeros_like(self._rewards)
        # Only the values are kept, so skip building the autograd graph.
        with torch.no_grad():
            for start in range(0, self._size, batch_size):
                end_idx = min(self._size, start + batch_size)

                bat_s = ptu.from_numpy(self._observations[start:end_idx])
                bat_a = ptu.from_numpy(self._actions[start:end_idx])

                self._q_curr[start:end_idx] = ptu.get_numpy(qf(bat_s, bat_a))

    # Disabled: would populate self._time_step_val with a value that is 1 at
    # terminal steps and decays by ``decay_ratio`` per step going backwards.
    # def calculate_terminal_val(self, decay_ratio=0.98):
    #     self._time_step_val = np.zeros_like(self._rewards)
    #     pre_val = 0
    #     for i in reversed(range(self._size)):
    #         if self._terminals[i] == True:
    #             pre_val = 1
    #             self._time_step_val[i] = pre_val
    #         else:
    #             pre_val = decay_ratio * pre_val
    #             self._time_step_val[i] = pre_val

    def rebuild_env_info_dict(self, idx):
        """Return ``{key: stored env-info row at idx}`` for every env-info key."""
        return {
            key: self._env_infos[key][idx]
            for key in self._env_info_keys
        }

    def batch_env_info_dict(self, indices):
        """Return ``{key: stored env-info rows at indices}`` for every env-info key."""
        return {
            key: self._env_infos[key][indices]
            for key in self._env_info_keys
        }

    def num_steps_can_sample(self):
        """Return the number of transitions currently stored."""
        return self._size

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the current buffer size."""
        return OrderedDict([
            ('size', self._size)
        ])

class CustomReplayBuffer(ReplayBuffer):
    """Fixed-capacity ring buffer of transitions with an optional ``diff`` field.

    Every slot stores an observation, action, reward, terminal flag and next
    observation, plus one row per ``env_info`` key. New samples are written
    at ``self._top``, which wraps back to 0 after ``max_replay_buffer_size``
    writes, so the oldest slots are overwritten once the buffer is full.

    ``self._diff`` is an optional per-slot array. Nothing in this class fills
    it; when it has been assigned from outside, :meth:`random_batch` adds its
    sampled rows to the batch under the key ``'diff'``.
    """

    def __init__(
        self,
        max_replay_buffer_size,
        observation_dim,
        action_dim,
        env_info_sizes,
    ):
        """Allocate zero-filled storage for ``max_replay_buffer_size`` slots.

        Args:
            max_replay_buffer_size: Number of slots in the ring buffer.
            observation_dim: Row width of the observation arrays.
            action_dim: Row width of the action array.
            env_info_sizes: Mapping from env-info key to the row width of
                the array allocated for it. ``None`` is treated as an empty
                mapping.
        """
        if env_info_sizes is None:
            env_info_sizes = {}

        self._observation_dim = observation_dim
        self._action_dim = action_dim
        self._max_replay_buffer_size = max_replay_buffer_size
        self._observations = np.zeros((max_replay_buffer_size, observation_dim))
        # It's a bit memory inefficient to save the observations twice,
        # but it makes the code *much* easier since you no longer have to
        # worry about termination conditions.
        self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
        self._actions = np.zeros((max_replay_buffer_size, action_dim))
        # Make everything a 2D np array to make it easier for other code to
        # reason about the shape of the data
        self._rewards = np.zeros((max_replay_buffer_size, 1))
        # self._terminals[i] = a terminal was received at time i
        self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')

        # random_batch() reads this attribute, so it must always exist;
        # None means "no diff data has been attached yet".
        self._diff = None

        # Define self._env_infos[key][i] to be the return value of env_info[key]
        # at time i
        self._env_infos = {
            key: np.zeros((max_replay_buffer_size, size))
            for key, size in env_info_sizes.items()
        }
        # Snapshot the keys as a list: a dict_keys view cannot be pickled and
        # would silently follow later mutations of the caller's dict.
        self._env_info_keys = list(env_info_sizes)

        self._top = 0
        self._size = 0

    def _write_transition(self, observation, action, reward,
                          next_observation, terminal):
        """Write one transition into slot ``self._top`` without advancing."""
        slot = self._top
        self._observations[slot] = observation
        self._actions[slot] = action
        self._rewards[slot] = reward
        self._terminals[slot] = terminal
        self._next_obs[slot] = next_observation

    def add_sample(self, observation, action, reward, next_observation,
                   terminal, env_info, **kwargs):
        """Store one transition and its env-info values, then advance.

        Args:
            observation: Value written to the observation row.
            action: Value written to the action row.
            reward: Value written to the reward row.
            next_observation: Value written to the next-observation row.
            terminal: Value written to the terminal row (stored as uint8).
            env_info: Mapping that must contain every key given in
                ``env_info_sizes`` at construction time.
            **kwargs: Accepted and ignored.
        """
        self._write_transition(
            observation, action, reward, next_observation, terminal,
        )
        for key in self._env_info_keys:
            self._env_infos[key][self._top] = env_info[key]
        self._advance()

    def add_sample_only(self, observation, action, reward, next_observation, terminal):
        """Store one transition without touching the env-info arrays."""
        self._write_transition(
            observation, action, reward, next_observation, terminal,
        )
        self._advance()

    def terminate_episode(self):
        """Do nothing; this buffer keeps no per-episode state."""
        pass

    def _advance(self):
        """Move the write cursor forward (wrapping) and grow ``_size`` up to capacity."""
        self._top = (self._top + 1) % self._max_replay_buffer_size
        if self._size < self._max_replay_buffer_size:
            self._size += 1

    def random_batch(self, batch_size, prob=None, get_idx=False):
        """Sample ``batch_size`` distinct stored transitions.

        Indices are always drawn without replacement, so ``np.random.choice``
        raises ``ValueError`` if ``batch_size > self._size`` (or, when
        ``prob`` is given, if it has fewer non-zero entries than
        ``batch_size``).

        Args:
            batch_size: Number of indices to draw from ``range(self._size)``.
            prob: Optional 1-D array-like of sampling probabilities, passed
                to ``np.random.choice`` as ``p``; uniform when omitted.
            get_idx: If truthy, also return the sampled indices.

        Returns:
            A dict of arrays indexed by the sampled indices, or the tuple
            ``(batch, indices)`` when ``get_idx`` is truthy. The ``'diff'``
            entry is present only if ``self._diff`` has been assigned.
        """
        if prob is None:
            indices = np.random.choice(self._size, batch_size, replace=False)
        else:
            indices = np.random.choice(self._size, batch_size, replace=False, p=prob)

        batch = dict(
            observations=self._observations[indices],
            actions=self._actions[indices],
            rewards=self._rewards[indices],
            terminals=self._terminals[indices],
            next_observations=self._next_obs[indices],
        )
        # ``_diff`` is never populated by this class; include it only when
        # a caller has attached it, instead of failing with AttributeError.
        if self._diff is not None:
            batch['diff'] = self._diff[indices]

        for key in self._env_info_keys:
            # An env-info key must not shadow one of the fields above.
            assert key not in batch
            batch[key] = self._env_infos[key][indices]

        if get_idx:
            return batch, indices
        return batch

    def rebuild_env_info_dict(self, idx):
        """Return ``{key: stored env-info row at idx}`` for every env-info key."""
        return {
            key: self._env_infos[key][idx]
            for key in self._env_info_keys
        }

    def batch_env_info_dict(self, indices):
        """Return ``{key: stored env-info rows at indices}`` for every env-info key."""
        return {
            key: self._env_infos[key][indices]
            for key in self._env_info_keys
        }

    def num_steps_can_sample(self):
        """Return the number of transitions currently stored."""
        return self._size

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the current buffer size."""
        return OrderedDict([
            ('size', self._size)
        ])
