from dependencies.ravens.ravens.environments.environment import EnvironmentNoRotationsWithHeightmap
from dependencies.ravens.ravens.tasks.align_box_corner import AlignBoxCorner
from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)

from multiworld.envs.mujoco.mujoco_env import MujocoEnv
import copy

from multiworld.core.multitask_env import MultitaskEnv
from ravens.environments.environment import Environment
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from gym import spaces
from huge.envs.env_utils import Discretized

import pybullet as p




class RavensEnv():
  """Goal-conditioned state interface around a Ravens ``Environment``.

  The constructor builds an ``Environment`` running the ``AlignBoxCorner``
  task. Observations are flat vectors made of object poses queried from
  PyBullet followed by the end-effector pose reported by the environment.
  """

  # Length of one flattened pose: position (3 values) followed by an
  # orientation quaternion (4 values).
  POSE_DIM = 7

  def __init__(self,
               disp=False,
               shared_memory=False,
               hz=240,
               use_egl=False,
               assets_root="./ravens/environments/assets/"):
    """Create the wrapped environment and declare its spaces.

    Args:
      disp: Forwarded positionally to ``Environment``.
      shared_memory: Forwarded positionally to ``Environment``.
      hz: Forwarded positionally to ``Environment``.
      use_egl: Forwarded positionally to ``Environment``.
      assets_root: Directory handed to ``Environment`` as its asset root.
        The default is relative to the current working directory, so pass
        an explicit path when running from elsewhere.
    """
    task = AlignBoxCorner()
    self._env = Environment(assets_root,
                            task,
                            disp,
                            shared_memory,
                            hz,
                            use_egl)

    # TODO: adjust
    # TODO: how do I get the state of suction or not?
    def unit_box():
      # One object pose followed by the end-effector pose, each bounded to
      # [-1, 1]. A fresh pair of arrays is built per call so the two spaces
      # below never share bound arrays.
      upper = np.concatenate([1.0 * np.ones(self.POSE_DIM),
                              1.0 * np.ones(self.POSE_DIM)])
      return spaces.Box(-upper, upper, dtype=np.float32)

    self._observation_space = unit_box()
    self.goal_space = unit_box()

    self.action_space = gym.spaces.Dict({
        'move_cmd':
            gym.spaces.Tuple(
                (self._env.position_bounds,
                 gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32))),
        'suction_cmd': gym.spaces.Discrete(2),  # Binary 0-1.
    })

    # End-effector pose appended to the object goal in ``_get_obs``.
    self.ee_init_pos = [0.4831041007489618, 0.029937637798535994, 0.34059017863897467, -0.00015432182101458617, 1.3429905759739741e-05, 0.16335377683469907, 0.9865675443669595]

  def step(self, action=None):
    """Forward ``action`` to the wrapped environment.

    When an action is given, a shallow copy with ``'acts_left'`` set to 0 is
    what gets forwarded; the caller's mapping is left untouched.

    Args:
      action: Mapping of action fields, or ``None`` (forwarded as ``None``).

    Returns:
      Whatever ``self._env.step`` returns.
    """
    if action is not None:
      # Copy instead of writing into the caller's dict: the extra key is
      # only meant for the wrapped environment.
      action = dict(action, acts_left=0)
    return self._env.step(action)

  @property
  def state_space(self):
    """Space of full states; this is the same object as ``goal_space``."""
    return self.goal_space

  @property
  def observation_space(self):
    """Dict space with one entry per key produced by ``_get_obs``."""
    return Dict([
        ('observation', self.state_space),
        ('desired_goal', self.goal_space),
        ('achieved_goal', self.state_space),
        ('state_observation', self.state_space),
        ('state_desired_goal', self.goal_space),
        ('state_achieved_goal', self.state_space),
    ])

  @staticmethod
  def _flat_poses(body_ids):
    """Concatenate (position, orientation) of every body id, in order.

    ``np.concatenate`` raises ``ValueError`` when ``body_ids`` is empty.
    """
    parts = []
    for body_id in body_ids:
      position, orientation = p.getBasePositionAndOrientation(body_id)
      parts.extend((position, orientation))
    return np.concatenate(parts)

  def get_world_obs(self, ):
    """Return flattened poses of the environment's objects.

    Returns:
      ``(obs, goal)``: ``obs`` holds the poses of the ids listed under
      ``obj_ids['rigid']`` and ``goal`` those under ``obj_ids['fixed']``.
    """
    obs = self._flat_poses(self._env.obj_ids['rigid'])
    goal = self._flat_poses(self._env.obj_ids['fixed'])
    return obs, goal

  def _get_obs(self, ):
    """Build the goal-env observation dict.

    The observation is the rigid-object poses followed by the current
    end-effector pose; the goal is the fixed-object poses followed by
    ``self.ee_init_pos``. All "achieved"/"observation" keys refer to the same
    observation array, and both "desired" keys to the same goal array.
    """
    world_obs, world_goal = self.get_world_obs()
    ee_obs = np.concatenate(self._env.get_ee_pose())
    obs = np.concatenate([world_obs, ee_obs])
    goal = np.concatenate([world_goal, self.ee_init_pos])

    return dict(
        observation=obs,
        desired_goal=goal,
        achieved_goal=obs,
        state_observation=obs,
        state_desired_goal=goal,
        state_achieved_goal=obs,
    )

class RavensGoalEnv(GymGoalEnvWrapper):
    """``GymGoalEnvWrapper`` around ``RavensEnv`` using Euclidean distances."""

    # Largest Euclidean distance at which ``compute_success`` reports success.
    SUCCESS_THRESHOLD = 0.05

    def __init__(self,
                 disp=False,
                 shared_memory=False,
                 hz=240,
                 use_egl=False):
        """Build a ``RavensEnv`` and wrap it.

        All four arguments are forwarded positionally to ``RavensEnv``.
        """
        env = RavensEnv(disp, shared_memory, hz, use_egl)

        super(RavensGoalEnv, self).__init__(
            env,
            observation_key='observation',
            goal_key='achieved_goal',
            state_goal_key='state_achieved_goal',
        )

    def compute_success(self, achieved_state, goal):
        """Return whether ``achieved_state`` is within the threshold of ``goal``.

        The comparison is on the Euclidean norm of the difference, taken over
        all elements of the arrays.
        """
        return np.linalg.norm(achieved_state - goal) < self.SUCCESS_THRESHOLD

    def goal_distance(self, state, goal_state):
        """Distance between the observation parts of two states.

        Previously the goal passed on was ``None``, which made the subtraction
        in ``compute_shaped_distance`` raise ``TypeError`` on every call.
        """
        achieved_state = self.observation(state)
        # NOTE(review): ``goal_state`` is mapped through the same
        # ``self.observation`` extractor as ``state`` so both vectors share a
        # layout; confirm this is the intended quantity rather than a
        # dedicated goal extractor provided by the wrapper.
        goal = self.observation(goal_state)
        return self.compute_shaped_distance(achieved_state, goal)

    def compute_shaped_distance(self, achieved_state, goal):
        """Return the Euclidean norm of ``achieved_state - goal``."""
        return np.linalg.norm(achieved_state - goal)

    def get_shaped_distance(self, states, goal_states):
        """Alias of ``compute_shaped_distance``.

        NOTE(review): no ``axis`` is given to the norm, so batched inputs
        yield a single scalar over the whole batch — confirm callers expect
        that.
        """
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Delegate rendering to ``self.base_env.render_image()``.

        NOTE(review): ``base_env`` is not assigned in this class and
        ``RavensEnv`` does not define ``render_image`` — verify the wrapped
        object provides it.
        """
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """Return an empty ``OrderedDict``; no diagnostics are computed."""
        return OrderedDict()