# coding=utf-8
# Copyright 2022 The Reincarnating Rl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A subclass of the cirular replay buffer to handle Monte Carlo rollouts."""


from dopamine.replay_memory import circular_replay_buffer as crb
from dopamine.replay_memory import sum_tree
from dopamine.replay_memory.circular_replay_buffer import ReplayElement
import gin
import numpy as np
from six.moves import range
from six.moves import zip

# Sentinel stored as the 'monte_carlo_reward' of a transition whose Monte Carlo
# return has not been filled in (yet).
UNAVAILABLE_MC_RETURN = -100000000.0


@gin.configurable
class OutOfGraphReplayBufferWithMC(crb.OutOfGraphReplayBuffer):
  """This is an extension of the circular RB that handles Monte Carlo rollouts.

  Every transition carries an extra 'monte_carlo_reward' storage element, and
  two kinds of returns can be sampled:
  - Regular n-step returns (n=update_horizon).
  - Monte Carlo returns. With reverse_fill=True they are written backwards
    over an episode once its terminal transition is added (transitions hold
    UNAVAILABLE_MC_RETURN until then). With reverse_fill=False they are
    computed at sampling time as a discounted rollout of at most
    monte_carlo_rollout_length rewards.
  """

  def __init__(self,
               observation_shape,
               stack_size,
               replay_capacity=1000000,
               batch_size=32,
               update_horizon=1,
               gamma=0.99,
               max_sample_attempts=1000,
               extra_storage_types=None,
               observation_dtype=np.uint8,
               monte_carlo_rollout_length=10,
               reverse_fill=True):
    """Initializes OutOfGraphReplayBufferWithMC.

    Note that not all constructor parameters are replicated here. The rest can
    be set via a gin config file.

    Args:
      observation_shape: tuple of ints.
      stack_size: int, number of frames to use in state stack.
      replay_capacity: int, number of transitions to keep in memory.
      batch_size: int.
      update_horizon: int, length of update ('n' in n-step update).
      gamma: float, the discount factor.
      max_sample_attempts: int, the maximum number of attempts allowed to
        get a sample.
      extra_storage_types: list (or tuple) of ReplayElements defining the type
        of the extra contents that will be stored and returned by
        sample_transition_batch. It is not modified; a 'monte_carlo_reward'
        element is appended to a copy.
      observation_dtype: np.dtype, type of the observations. Defaults to
        np.uint8 for Atari 2600.
      monte_carlo_rollout_length: int, maximum number of rewards summed for
        the sampling-time Monte Carlo rollout (only used if
        reverse_fill=False).
      reverse_fill: bool, specifies whether we reverse-fill the returns upon
        finishing an episode, or whether we do n-step rollouts (where
        n=monte_carlo_rollout_length) for computing the MonteCarlo returns.
    """
    self._monte_carlo_rollout_length = monte_carlo_rollout_length
    self._reverse_fill = reverse_fill
    # Build a new list instead of extending in place: `+=` on a caller-owned
    # (or gin-shared) list would mutate it, so reusing that list for another
    # buffer would register 'monte_carlo_reward' twice. list() also accepts
    # tuples.
    extra_storage_types = list(extra_storage_types or []) + [
        ReplayElement('monte_carlo_reward', (), np.float32)
    ]

    super(OutOfGraphReplayBufferWithMC, self).__init__(
        observation_shape,
        stack_size,
        replay_capacity,
        batch_size,
        update_horizon=update_horizon,
        gamma=gamma,
        max_sample_attempts=max_sample_attempts,
        extra_storage_types=extra_storage_types,
        observation_dtype=observation_dtype)

  def sample_index_batch(self, batch_size):
    """Returns a batch of valid indices sampled uniformly.

    Indices closer to the cursor than the sampling horizon are excluded. The
    horizon is update_horizon, or max(update_horizon,
    monte_carlo_rollout_length) when Monte Carlo returns are computed at
    sampling time (reverse_fill=False).

    Args:
      batch_size: int, number of indices returned.

    Returns:
      list of ints, a batch of valid indices sampled uniformly.

    Raises:
      RuntimeError: If the batch was not constructed after maximum number of
        tries.
    """
    if self._reverse_fill:
      horizon = self._update_horizon
    else:
      horizon = max(self._update_horizon, self._monte_carlo_rollout_length)
    if self.is_full():
      # add_count >= self._replay_capacity > self._stack_size
      min_id = self.cursor() - self._replay_capacity + self._stack_size - 1
      max_id = self.cursor() - horizon
    else:
      # add_count < self._replay_capacity
      min_id = self._stack_size - 1
      max_id = self.cursor() - horizon
      if max_id <= min_id:
        raise RuntimeError('Cannot sample a batch with fewer than stack size '
                           '({}) + horizon ({}) transitions.'.
                           format(self._stack_size, horizon))

    indices = []
    attempt_count = 0
    while (len(indices) < batch_size and
           attempt_count < self._max_sample_attempts):
      attempt_count += 1
      index = np.random.randint(min_id, max_id) % self._replay_capacity
      if self.is_valid_transition(index):
        indices.append(index)
    if len(indices) != batch_size:
      raise RuntimeError(
          'Max sample attempts: Tried {} times but only sampled {}'
          ' valid indices. Batch size is {}'.
          format(self._max_sample_attempts, len(indices), batch_size))

    return indices

  def _get_rollout_length(self, start_index, max_length):
    """Returns the length of a rollout and whether it hit a terminal.

    The rollout covers at most `max_length` transitions starting at
    `start_index` (wrapping around the circular storage) and ends right after
    the first terminal transition, if any.

    Args:
      start_index: int, index of the first transition of the rollout.
      max_length: int, maximum number of transitions in the rollout.

    Returns:
      A tuple (length, is_terminal): the rollout length and whether a
      terminal transition was encountered within `max_length` steps.
    """
    rollout_indices = [(start_index + j) % self._replay_capacity
                       for j in range(max_length)]
    terminals = self._store['terminal'][rollout_indices].astype(bool)
    if not terminals.any():
      return max_length, False
    # np.argmax of a bool array returns the index of the first True.
    return int(np.argmax(terminals)) + 1, True

  def sample_transition_batch(self, batch_size=None, indices=None):
    """Returns a batch of transitions (including any extra contents).

    There are two different horizons being considered here, one for the regular
    transitions (update_horizon), and one for doing Monte Carlo rollouts for
    estimating returns (monte_carlo_rollout_length, used only when
    reverse_fill=False; otherwise the stored reverse-filled return is used).

    Args:
      batch_size: int, number of transitions returned. If None, the default
        batch_size will be used.
      indices: None or list of ints, the indices of every transition in the
        batch. If None, sample the indices uniformly.

    Returns:
      transition_batch: tuple of np.arrays with the shape and type as in
        get_transition_elements().

    Raises:
      ValueError: If an element to be sampled is missing from the replay buffer.
    """
    if batch_size is None:
      batch_size = self._batch_size
    if indices is None:
      indices = self.sample_index_batch(batch_size)
    assert len(indices) == batch_size

    transition_elements = self.get_transition_elements(batch_size)
    batch_arrays = self._create_batch_arrays(batch_size)
    assert len(transition_elements) == len(batch_arrays)

    if not self._reverse_fill:
      # gamma^0 .. gamma^(L-1) for the Monte Carlo rollouts. The parent's
      # _cumulative_discount_vector is sized for the n-step update (dopamine
      # builds it from update_horizon), so slicing it is too short whenever
      # monte_carlo_rollout_length > update_horizon.
      monte_carlo_discounts = np.array(
          [self._gamma**k for k in range(self._monte_carlo_rollout_length)],
          dtype=np.float32)

    for batch_element, state_index in enumerate(indices):
      # Regular n-step transition, cut short by a terminal transition.
      trajectory_length, is_terminal_transition = self._get_rollout_length(
          state_index, self._update_horizon)
      next_state_index = state_index + trajectory_length
      trajectory_discount_vector = (
          self._cumulative_discount_vector[:trajectory_length])
      trajectory_rewards = self.get_range(self._store['reward'], state_index,
                                          next_state_index)

      if not self._reverse_fill:
        # Monte Carlo rollout. Its rewards must span the rollout's own length
        # (not the n-step trajectory_length) to line up with its discounts.
        monte_carlo_length, _ = self._get_rollout_length(
            state_index, self._monte_carlo_rollout_length)
        monte_carlo_discount_vector = monte_carlo_discounts[:monte_carlo_length]
        monte_carlo_rewards = self.get_range(self._store['reward'], state_index,
                                             state_index + monte_carlo_length)

      # Fill the contents of each array in the sampled batch.
      for element_array, element in zip(batch_arrays, transition_elements):
        name = element.name
        if name == 'state':
          element_array[batch_element] = self.get_observation_stack(state_index)
        elif name == 'reward':
          # Discounted sum of rewards in the n-step trajectory.
          element_array[batch_element] = trajectory_discount_vector.dot(
              trajectory_rewards)
        elif name == 'monte_carlo_reward' and not self._reverse_fill:
          # Discounted sum of rewards in the Monte Carlo rollout.
          element_array[batch_element] = monte_carlo_discount_vector.dot(
              monte_carlo_rewards)
        elif name == 'next_state':
          element_array[batch_element] = self.get_observation_stack(
              next_state_index % self._replay_capacity)
        elif name in ('next_action', 'next_reward'):
          # Mirror the parent buffer: these are not stored columns, so without
          # this branch they would be left as allocated by
          # _create_batch_arrays.
          element_array[batch_element] = self._store[name[len('next_'):]][
              next_state_index % self._replay_capacity]
        elif name == 'terminal':
          element_array[batch_element] = is_terminal_transition
        elif name == 'indices':
          element_array[batch_element] = state_index
        elif name in self._store:
          # Stored columns, including the reverse-filled 'monte_carlo_reward'.
          element_array[batch_element] = self._store[name][state_index]
        # We assume the other elements are filled in by the subclass.

    return batch_arrays

  def _record_monte_carlo_returns(self):
    """Reverse-fills discounted returns for the episode that just terminated.

    Starting at the most recently added transition, walks backwards and
    writes G_t = r_t + gamma * G_{t+1} into the 'monte_carlo_reward' store.
    The walk stops at the previous terminal transition, at a previous episode
    end recorded via `episode_end`, or once every stored slot has been
    visited. It therefore never wraps into slots that were never written.
    """
    capacity = self._replay_capacity
    # Only these slots have actually been written.
    num_stored = capacity if self.is_full() else self.cursor()
    index = (self.cursor() - 1) % capacity
    returns = self._store['reward'][index]
    self._store['monte_carlo_reward'][index] = returns
    for _ in range(num_stored - 1):
      index = (index - 1) % capacity
      if (self._store['terminal'][index] == 1 or
          index in self._episode_end_indices):
        break
      returns = returns * self._gamma + self._store['reward'][index]
      self._store['monte_carlo_reward'][index] = returns

  def add(self,
          observation,
          action,
          reward,
          terminal,
          *args,
          priority=None,
          episode_end=False):
    """Adds a transition to the replay memory.

    The 'monte_carlo_reward' element is not supplied by the caller. It is
    inserted at its position in the add signature (after any
    extra_storage_types contents and before a trailing 'priority', if
    present), initialized to UNAVAILABLE_MC_RETURN. With reverse_fill=True,
    adding a terminal transition reverse-fills the returns of the episode it
    ends. Episodes ended only via `episode_end` are not reverse-filled.

    Args:
      observation: np.array with shape observation_shape.
      action: int, the action in the transition.
      reward: float, the reward received in the transition.
      terminal: A uint8 acting as a boolean indicating whether the transition
                was terminal (1) or not (0).
      *args: extra contents with shapes and dtypes according to
        extra_storage_types.
      priority: float, unused in the circular replay buffer, but may be used
        in child classes like PrioritizedReplayBuffer.
      episode_end: bool, whether this experience is the last experience in
        the episode. This is useful for tasks that terminate due to time-out,
        but do not end on a terminal state. Overloading 'terminal' may not
        be sufficient in this case, since 'terminal' is passed to the agent
        for training. 'episode_end' allows the replay buffer to determine
        episode boundaries without passing that information to the agent.
    """
    if priority is not None:
      args = args + (priority,)

    # 'monte_carlo_reward' follows any extra_storage_types contents in the
    # signature, so it cannot simply be placed right after `terminal`.
    values = [observation, action, reward, terminal, *args]
    signature_names = [
        element.name for element in self.get_add_args_signature()
    ]
    mc_position = signature_names.index('monte_carlo_reward')
    head, tail = values[:mc_position], values[mc_position:]

    # `reward` stands in for the Monte Carlo return during type checking (the
    # 'monte_carlo_reward' element has shape ()).
    self._check_add_types(*head, reward, *tail)
    if self.is_empty() or self._store['terminal'][self.cursor() - 1] == 1:
      for _ in range(self._stack_size - 1):
        # Child classes can rely on the padding transitions being filled with
        # zeros. This is useful when there is a priority argument.
        self._add_zero_transition()
    if episode_end or terminal:
      self._episode_end_indices.add(self.cursor())
      self._next_experience_is_episode_start = True
    else:
      self._episode_end_indices.discard(self.cursor())  # If present

    # UNAVAILABLE_MC_RETURN marks the MC return as not filled in yet. This is
    # needed to use the MC replay buffer in the online setting.
    self._add(*head, UNAVAILABLE_MC_RETURN, *tail)
    if terminal and self._reverse_fill:
      self._record_monte_carlo_returns()


@gin.configurable
class OutOfGraphPrioritizedReplayBufferwithMC(OutOfGraphReplayBufferWithMC):
  """A prioritized replay buffer with monte-carlo support.

  Priorities live in a sum tree and indices are sampled as in Schaul et al.
  (2015). Sampled batches carry an extra 'sampling_probabilities' element
  holding the sum-tree priorities of the sampled indices.
  """

  def __init__(self,
               observation_shape,
               stack_size,
               replay_capacity,
               batch_size,
               update_horizon=1,
               gamma=0.99,
               observation_dtype=np.uint8,
               reverse_fill=True):
    """Initializes OutOfGraphPrioritizedReplayBufferwithMC.

    monte_carlo_rollout_length, max_sample_attempts and extra_storage_types
    are not forwarded. They take whatever OutOfGraphReplayBufferWithMC's
    constructor resolves (its defaults or, presumably, its gin bindings).

    Args:
      observation_shape: tuple of ints.
      stack_size: int, number of frames to use in state stack.
      replay_capacity: int, number of transitions to keep in memory.
      batch_size: int.
      update_horizon: int, length of update ('n' in n-step update).
      gamma: float, the discount factor.
      observation_dtype: np.dtype, type of the observations. Defaults to
        np.uint8 for Atari 2600.
      reverse_fill: bool, specifies whether we reverse-fill the returns upon
        finishing an episode, or whether we do n-step rollouts (where
        n=monte_carlo_rollout_length) for computing the MonteCarlo returns.

    See prioritized_replay_buffer.py for more details.
    """
    super().__init__(
        observation_shape=observation_shape,
        stack_size=stack_size,
        replay_capacity=replay_capacity,
        batch_size=batch_size,
        update_horizon=update_horizon,
        gamma=gamma,
        observation_dtype=observation_dtype,
        reverse_fill=reverse_fill)

    self.sum_tree = sum_tree.SumTree(replay_capacity)

  def get_add_args_signature(self):
    """The signature of the add function: the parent's plus 'priority'."""
    parent_add_signature = super().get_add_args_signature()
    add_signature = parent_add_signature + [
        ReplayElement('priority', (), np.float32)
    ]
    return add_signature

  def _add(self, *args):
    """Internal add method to add to the underlying memory arrays.

    `args` follows get_add_args_signature(). The 'priority' value is written
    to the sum tree at the current cursor; every other element is stored as
    a regular transition.
    """
    self._check_args_length(*args)

    # The priority is used as given. Callers presumably pass the maximum
    # priority so far, per Schaul et al. (2015). It is split off from the
    # stored contents here.
    transition = {}
    for i, element in enumerate(self.get_add_args_signature()):
      if element.name == 'priority':
        priority = args[i]
      else:
        transition[element.name] = args[i]

    self.sum_tree.set(self.cursor(), priority)
    super()._add_transition(transition)

  def _is_sampleable(self, index):
    """Returns whether `index` can start a sampled transition.

    Besides is_valid_transition, when Monte Carlo returns are computed at
    sampling time (reverse_fill=False), the index must lie more than
    max(update_horizon, monte_carlo_rollout_length) steps before the cursor.
    This is the same margin the uniform sampler enforces through its max_id.
    Without it, a rollout could read past the cursor. (Dopamine's
    is_valid_transition only accounts for update_horizon.)
    """
    if not self.is_valid_transition(index):
      return False
    if self._reverse_fill:
      return True
    horizon = max(self._update_horizon, self._monte_carlo_rollout_length)
    return (self.cursor() - index) % self._replay_capacity > horizon

  def sample_index_batch(self, batch_size):
    """Returns a batch of valid indices sampled as in Schaul et al. (2015).

    Indices are drawn by stratified sampling from the sum tree. Each invalid
    index is replaced by (non-stratified) resampling, drawing on an attempt
    budget of max_sample_attempts shared by the whole batch.

    Args:
      batch_size: int, number of indices returned.

    Returns:
      list of ints, a batch of valid indices.

    Raises:
      RuntimeError: If the attempt budget is exhausted before every index in
        the batch is valid. Previously an invalid index could be returned
        silently in that case.
    """
    # Sample stratified indices. Some of them might be invalid.
    indices = self.sum_tree.stratified_sample(batch_size)
    allowed_attempts = self._max_sample_attempts
    for i, index in enumerate(indices):
      while not self._is_sampleable(index):
        if allowed_attempts == 0:
          raise RuntimeError(
              'Max sample attempts: Tried {} times but only sampled {}'
              ' valid indices. Batch size is {}'.
              format(self._max_sample_attempts, i, batch_size))
        # If index i is not valid keep sampling others. Note that this
        # is not stratified.
        index = self.sum_tree.sample()
        allowed_attempts -= 1
      indices[i] = index
    return indices

  def sample_transition_batch(self, batch_size=None, indices=None):
    """Returns a batch of transitions with extra storage and the priorities."""
    transition = super().sample_transition_batch(batch_size, indices)
    transition_elements = self.get_transition_elements(batch_size)
    transition_names = [e.name for e in transition_elements]
    probabilities_index = transition_names.index('sampling_probabilities')
    indices_index = transition_names.index('indices')
    indices = transition[indices_index]
    # The parent did not fill the probabilities array. Fill it with the
    # contents of the sum tree.
    transition[probabilities_index][:] = self.get_priority(indices)
    return transition

  def set_priority(self, indices, priorities):
    """Sets the priority of the given elements according to Schaul et al."""
    assert indices.dtype == np.int32, ('Indices must be integers, '
                                       'given: {}'.format(indices.dtype))
    for index, priority in zip(indices, priorities):
      self.sum_tree.set(index, priority)

  def get_priority(self, indices):
    """Fetches the priorities corresponding to a batch of memory indices.

    Args:
      indices: np.array of int32, memory indices.

    Returns:
      np.array of float32 with one sum-tree priority per index.
    """
    assert indices.shape, 'Indices must be an array.'
    assert indices.dtype == np.int32, ('Indices must be int32s, '
                                       'given: {}'.format(indices.dtype))
    batch_size = len(indices)
    priority_batch = np.empty((batch_size), dtype=np.float32)
    for i, memory_index in enumerate(indices):
      priority_batch[i] = self.sum_tree.get(memory_index)
    return priority_batch

  def get_transition_elements(self, batch_size=None):
    """Returns a 'type signature' for sample_transition_batch."""
    parent_transition_type = super().get_transition_elements(batch_size)
    probabilities_type = [
        ReplayElement('sampling_probabilities', (batch_size,), np.float32)
    ]
    return parent_transition_type + probabilities_type
