# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple grid-world environment.

The task here is to walk to the (max_x, max_y) position in a square grid.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from typing import Any, Dict, Tuple, Union


class GridWalk(object):
  """Walk on grid to target location.

  The agent occupies a cell (x, y) of a `length` x `length` grid and has to
  reach the target cell (length - 1, length - 1). Four actions are available:

    0: x + 1,   1: y + 1,   2: x - 1,   3: y - 1

  A move that would leave the grid keeps the agent in place. After every move
  the reward is exp(-2 * d / length), where d is the Manhattan distance from
  the new cell to the target. The target cell is absorbing: stepping while
  already on it returns reward 0 and done=True.
  """

  def __init__(self, length, tabular_obs = True, start=None):
    """Initializes the environment.

    The initial position is drawn uniformly over the whole grid, regardless of
    `start`; `start` only restricts the positions chosen by `reset()`.

    Args:
      length: The length of the square gridworld.
      tabular_obs: Whether to use tabular observations. Otherwise observations
        are x, y coordinates.
      start: Number of tabular states (0, ..., start - 1) that `reset()` may
        pick as the initial state. Defaults to all `length ** 2` states.

    Raises:
      ValueError: If `start` is not in [1, length ** 2].
    """
    n_state = length ** 2
    if start is None:
      start = n_state
    elif not 0 < start <= n_state:
      # An out-of-range value would otherwise surface later, as an off-grid
      # position after reset() or as an indexing error in _get_matrices().
      raise ValueError(
          'start must be in [1, %s], got %s.' % (n_state, start))

    self._length = length
    self._tabular_obs = tabular_obs
    self._x = np.random.randint(length)
    self._y = np.random.randint(length)
    self._n_state = n_state
    self._n_action = 4
    self._target_x = length - 1
    self._target_y = length - 1
    self._start = start
    self._start_indices = list(range(self._start))

  def reset(self):
    """Resets the agent to a random square among the allowed start states."""
    obs = np.random.choice(self._start_indices)
    (x, y) = self.get_xy_obs(obs)
    self._x = x
    self._y = y

    return self._get_obs()

  def _get_obs(self):
    """Gets current observation."""
    if self._tabular_obs:
      return self._x * self._length + self._y
    else:
      return np.array([self._x, self._y])

  def get_tabular_obs(self, xy_obs):
    """Gets tabular observation given non-tabular (x,y) observation."""
    return self._length * xy_obs[Ellipsis, 0] + xy_obs[Ellipsis, 1]

  def get_xy_obs(self, state):
    """Gets (x,y) coordinates given tabular observation."""
    x = state // self._length
    y = state % self._length
    return np.stack([x, y], axis=-1)

  def _is_target(self, x, y):
    """Returns whether (x, y) is the target cell."""
    return x == self._target_x and y == self._target_y

  def _move(self, x, y, action):
    """Returns the cell reached from (x, y) by `action`.

    Moves that would leave the grid keep the coordinate unchanged.

    Raises:
      ValueError: If the input action is invalid.
    """
    last = self._length - 1
    if action == 0:
      if x < last:
        x = x + 1
    elif action == 1:
      if y < last:
        y = y + 1
    elif action == 2:
      if x > 0:
        x = x - 1
    elif action == 3:
      if y > 0:
        y = y - 1
    else:
      raise ValueError('Invalid action %s.' % action)
    return x, y

  def _reward(self, x, y):
    """Returns the reward for arriving at (x, y)."""
    taxi_distance = (np.abs(x - self._target_x) +
                     np.abs(y - self._target_y))
    return np.exp(-2. * taxi_distance / self._length)

  def _get_matrices(self, policy, state_action=True):
    """Builds the initial distribution, rewards and transitions under `policy`.

    All quantities are indexed by the flat state-action index
    `state * num_actions + action`.

    Args:
      policy: Object providing `get_probabilities(mu_states)` and
        `get_probability(state, action)`. The result of `get_probabilities`
        is multiplied elementwise with a [num_states, num_actions] array, so
        it is assumed to broadcast against that shape (confirm against the
        policy implementation).
      state_action: Unused; kept for interface compatibility.

    Returns:
      mu: Flat array; the uniform distribution over the start states weighted
        by `policy.get_probabilities`.
      rewards: Array of size num_states * num_actions with the reward of each
        (state, action) pair; 0 for every action taken in the target state.
      transitions: Square matrix where entry [next_idx, idx] is the policy's
        probability of the next action in the (deterministic) next state.
    """
    n_action = self._n_action
    dim = self._n_state * n_action
    transitions = np.zeros([dim, dim])  # Indexed as [next index, index].
    rewards = np.zeros([dim])

    # Uniform distribution over the allowed start states.
    mu_states = np.zeros([self._n_state])
    mu_states[self._start_indices] = 1
    mu_states = mu_states / np.sum(mu_states)

    mu_stateaction = np.array(
        [mu_states for _ in range(n_action)]).transpose()
    mu_probs = policy.get_probabilities(mu_states)
    mu = (mu_stateaction * mu_probs).flatten()

    for x in range(self._length):
      for y in range(self._length):
        state = x * self._length + y
        done = self._is_target(x, y)

        for action in range(n_action):
          idx = state * n_action + action

          if done:
            # The target is absorbing and yields no further reward.
            next_x, next_y = x, y
            reward = 0
          else:
            next_x, next_y = self._move(x, y, action)
            reward = self._reward(next_x, next_y)

          next_state = next_x * self._length + next_y
          rewards[idx] = reward

          for next_action in range(n_action):
            next_idx = next_state * n_action + next_action
            policy_prob = policy.get_probability(next_state, next_action)
            transitions[next_idx, idx] = policy_prob

    return mu, rewards, transitions

  def step(self, action):
    """Perform a step in the environment.

    Args:
      action: A valid action (one of 0, 1, 2, 3).

    Returns:
      next_obs: Observation after action is applied.
      reward: Environment step reward.
      done: Whether the episode has terminated. This is True only when the
        agent was already on the target before the step; the step that
        reaches the target still reports False.
      info: A dictionary of additional environment information.

    Raises:
      ValueError: If the input action is invalid.
    """
    # Computing the move first validates the action even in the terminal
    # state, where the move itself is discarded.
    next_x, next_y = self._move(self._x, self._y, action)

    if self._is_target(self._x, self._y):
      return self._get_obs(), 0, True, {}

    self._x = next_x
    self._y = next_y
    reward = self._reward(next_x, next_y)

    return self._get_obs(), reward, False, {}

  @property
  def num_states(self):
    """Number of grid cells (length ** 2)."""
    return self._n_state  # pytype: disable=bad-return-type  # bind-properties

  @property
  def num_actions(self):
    """Number of discrete actions."""
    return self._n_action

  @property
  def state_dim(self):
    """Observation dimension: 1 for tabular, 2 for (x, y) observations."""
    return 1 if self._tabular_obs else 2

  @property
  def action_dim(self):
    """Same value as `num_actions`."""
    return self._n_action
