from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)

from multiworld.envs.mujoco.mujoco_env import MujocoEnv
import copy

from multiworld.core.multitask_env import MultitaskEnv
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from envs.base_envs import BenchEnv
from d4rl.kitchen.kitchen_envs import KitchenMicrowaveKettleLightTopLeftBurnerV0
from gym import spaces

def get_kitchen_benchmark_goals():
    """Build the benchmark goal tables for the kitchen environment.

    Only the ``slide_cabinet`` task is currently enabled; the other kitchen
    objects are kept as comments (target joint values and the qpos indices
    they apply to) so they can be re-enabled later.

    Returns:
        A tuple ``(obs_element_goals, obs_element_indices, goal_configs)`` of
        parallel lists with one entry per goal configuration:
          * ``obs_element_goals[k]``: numpy array of target joint values,
          * ``obs_element_indices[k]``: numpy array of the qpos indices those
            values apply to,
          * ``goal_configs[k]``: list of the object names making up goal k.
    """
    # object name -> target joint values
    target_values = {
        # 'bottom_burner': [-0.88, -0.01],
        # 'light_switch':  [-0.69, -0.05],
        'slide_cabinet': [0.37],
        # 'hinge_cabinet': [0., 0.5],
        # 'microwave':     [-0.5],
        # 'kettle':        [-0.23, 0.75, 1.62],
    }
    # object name -> qpos indices of the joints listed above
    target_qpos_idxs = {
        # 'bottom_burner': [9, 10],
        # 'light_switch':  [17, 18],
        'slide_cabinet': [19],
        # 'hinge_cabinet': [20, 21],
        # 'microwave':     [22],
        # 'kettle':        [23, 24, 25],
    }
    task_names = [
        # 'bottom_burner',
        # 'light_switch',
        'slide_cabinet',
        # 'hinge_cabinet',
        # 'microwave',
        # 'kettle',
    ]

    # Single-object goals: only the first task name is used.
    goal_configs = [[name] for name in task_names[:1]]
    # Two-object goals (disabled):
    #   for i, j in combinations([1, 2, 3, 5], 2):
    #       goal_configs.append([task_names[i], task_names[j]])

    obs_element_goals = [
        np.concatenate([target_values[obj] for obj in config])
        for config in goal_configs
    ]
    obs_element_indices = [
        np.concatenate([target_qpos_idxs[obj] for obj in config])
        for config in goal_configs
    ]
    return obs_element_goals, obs_element_indices, goal_configs
class Kitchen3DIntermediateEnv(BenchEnv):
  """Kitchen environment whose observation and goal are a 3-D end-effector pose.

  Wraps d4rl's ``KitchenMicrowaveKettleLightTopLeftBurnerV0``. The dicts
  returned by ``reset``/``step`` use ``self._env.get_ee_pose()`` as the
  observation / achieved goal and ``self.goal`` (from ``generate_goal``) as the
  desired goal.

  Two notions of "goal" coexist in this class:
    * ``self.goal``: the 3-D end-effector target placed in observations.
    * ``self.goal_idx``: an integer index into the benchmark task tables
      returned by ``get_kitchen_benchmark_goals`` (object joint targets). It
      is the default task for ``compute_reward``, ``compute_success`` and
      ``render_goal``.
  """

  def __init__(self, action_repeat=1, use_goal_idx=False, log_per_goal=False,  control_mode='end_effector', width=64):
    super().__init__(action_repeat, width)
    self.use_goal_idx = use_goal_idx
    self.log_per_goal = log_per_goal

    with self.LOCK:
      self._env = KitchenMicrowaveKettleLightTopLeftBurnerV0(
          frame_skip=16, control_mode=control_mode, imwidth=width, imheight=width)

      self._env.sim_robot.renderer._camera_settings = dict(
          distance=3, lookat=[-0.3, .5, 2.], azimuth=90, elevation=-60)

      # Replace the wrapped env's observation space with a symmetric box
      # spanning half of its obs_dim.
      obs_upper = 8.0 * np.ones(self._env.obs_dim // 2)
      obs_lower = -obs_upper
      self._env.observation_space = spaces.Box(obs_lower, obs_upper, dtype=np.float32)
      print("observation space in kitchen", self._env.observation_space)

    # No goal image is cached yet. This used to be True with no cached image,
    # which made render_goal() raise AttributeError before the first reset().
    self.rendered_goal = False
    self.rendered_goal_obj = None
    self._env.reset()
    self.init_qpos = self._env.sim.data.qpos.copy()
    self.goal_idx = 0
    self.obs_element_goals, self.obs_element_indices, self.goal_configs = get_kitchen_benchmark_goals()
    self.goals = list(range(len(self.obs_element_goals)))
    self.goal = self.generate_goal()

  def generate_goal(self):
    """Return the fixed 3-D end-effector target used as the desired goal."""
    # TODO: sample goals within the reachable workspace instead of a constant,
    # e.g. np.random.random(size=(3,)) - 0.5 + np.array([-1, 0, 2]) once the
    # per-dimension bounds are known.
    return np.array([-0.12, 0.65, 2.6])

  def internal_extract_state(self, obs):
    """Return the first three entries of ``obs``."""
    return obs[:3]

  def set_goal_idx(self, idx):
    """Select the benchmark task index used by reward/success/goal rendering."""
    self.goal_idx = idx

  def get_goal_idx(self):
    """Return the currently selected benchmark task index."""
    return self.goal_idx

  def get_goals(self):
    """Return the list of valid benchmark task indices."""
    return self.goals

  def render_image(self):
    """Render the wrapped env as an RGB array."""
    return self._env.render(mode="rgb_array")

  def render(self):
    """Render the wrapped env in human (on-screen) mode."""
    return self._env.render(mode="human")

  @property
  def state_space(self):
    """Box of shape (3,) in [-1, 1] describing the end-effector state."""
    return gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

  @property
  def goal_space(self):
    """Box of shape (3,) in [-1, 1] describing the end-effector goal."""
    return gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

  @property
  def action_space(self):
    """Action space of the wrapped kitchen env."""
    return self._env.action_space

  @property
  def observation_space(self):
    """Goal-env ``Dict`` space matching the keys produced by ``_get_obs``."""
    return Dict([
        ('observation', self.state_space),
        ('desired_goal', self.goal_space),
        ('achieved_goal', self.state_space),
        ('state_observation', self.state_space),
        ('state_desired_goal', self.goal_space),
        ('state_achieved_goal', self.state_space),
    ])

  def _get_obs(self):
    """Build the goal-env observation dict from the current end-effector pose."""
    obs = self._env.get_ee_pose()
    goal = self.goal
    return dict(
        observation=obs,
        desired_goal=goal,
        achieved_goal=obs,
        state_observation=obs,
        state_desired_goal=goal,
        state_achieved_goal=obs,
    )

  def step(self, action):
    """Apply ``action`` up to ``self._action_repeat`` times.

    The reward is currently always 0; the wrapped env's reward is discarded
    and ``compute_reward`` is not used here. Stops early when the wrapped env
    reports ``done``.

    Returns:
      ``(obs, total_reward, done, info)``, where ``obs`` is the dict from
      ``_get_obs``.
    """
    total_reward = 0.0
    # Defaults so a non-positive action_repeat does not raise UnboundLocalError.
    done = False
    info = {}
    for _ in range(self._action_repeat):
      _, _, done, info = self._env.step(action)
      if done:
        break
    obs = self._get_obs()
    # Forward any metric_* entries of the observation into info.
    for k, v in obs.items():
      if 'metric_' in k:
        info[k] = v
    return obs, total_reward, done, info

  def compute_reward(self, goal=None):
    """Return the negative L2 distance between task joints and their targets.

    Args:
      goal: integer index into the benchmark task tables. Defaults to
        ``self.goal_idx``. The old default, ``self.goal``, is the 3-D
        end-effector array and cannot index the tables.
    """
    if goal is None:
      goal = self.goal_idx
    qpos = self._env.sim.data.qpos.copy()
    idxs = self.obs_element_indices[goal]
    targets = self.obs_element_goals[goal]

    if len(idxs) > 9:
      # Only entries past the first 9 are compared. Presumably the first 9
      # would be arm joints; confirm against the goal tables.
      return -np.linalg.norm(qpos[idxs][9:] - targets[9:])
    return -np.linalg.norm(qpos[idxs] - targets)

  def compute_success(self, goal=None):
    """Check per-object success against the target configuration of a task.

    Args:
      goal: integer index into the benchmark task tables. Defaults to
        ``self.goal_idx``.

    Returns:
      ``(task_success, all_success)`` as ints in {0, 1}. The first covers only
      the objects named in the task; the second covers every tracked object.
      Objects outside the task are compared against ``init_qpos``.
    """
    if goal is None:
      goal = self.goal_idx
    qpos = self._env.sim.data.qpos.copy()

    # Target qpos: initial configuration with the task's joints overwritten.
    goal_qpos = self.init_qpos.copy()
    goal_qpos[self.obs_element_indices[goal]] = self.obs_element_goals[goal]

    per_obj_success = {
        'bottom_burner': ((qpos[9] < -0.38) and (goal_qpos[9] < -0.38)) or ((qpos[9] > -0.38) and (goal_qpos[9] > -0.38)),
        'top_burner':    ((qpos[13] < -0.38) and (goal_qpos[13] < -0.38)) or ((qpos[13] > -0.38) and (goal_qpos[13] > -0.38)),
        'light_switch':  ((qpos[17] < -0.25) and (goal_qpos[17] < -0.25)) or ((qpos[17] > -0.25) and (goal_qpos[17] > -0.25)),
        'slide_cabinet': abs(qpos[19] - goal_qpos[19]) < 0.1,
        'hinge_cabinet': abs(qpos[21] - goal_qpos[21]) < 0.2,
        'microwave':     abs(qpos[22] - goal_qpos[22]) < 0.2,
        # Only qpos[23:25] (two entries) is compared for the kettle.
        'kettle': np.linalg.norm(qpos[23:25] - goal_qpos[23:25]) < 0.2,
    }

    task_rel_success = all(per_obj_success[obj] for obj in self.goal_configs[goal])
    all_obj_success = all(per_obj_success.values())
    return int(task_rel_success), int(all_obj_success)

  def render_goal(self):
    """Return an RGB image with the current task's objects at their targets.

    The image is cached until the next ``reset``. The simulator state is
    restored afterwards, even if rendering raises.
    """
    if self.rendered_goal:
      return self.rendered_goal_obj

    backup_qpos = self._env.sim.data.qpos.copy()
    backup_qvel = self._env.sim.data.qvel.copy()

    qpos = self.init_qpos.copy()
    qpos[self.obs_element_indices[self.goal_idx]] = self.obs_element_goals[self.goal_idx]

    try:
      self._env.set_state(qpos, np.zeros(len(self._env.init_qvel)))
      goal_obs = self._env.render('rgb_array', width=self._env.imwidth, height=self._env.imheight)
    finally:
      self._env.set_state(backup_qpos, backup_qvel)

    self.rendered_goal = True
    self.rendered_goal_obj = goal_obs
    return goal_obs

  def reset(self):
    """Reset the wrapped env and, unless ``use_goal_idx``, sample a new task index.

    Returns:
      The observation dict from ``_get_obs``.
    """
    with self.LOCK:
      self._env.reset()
    if not self.use_goal_idx:
      self.goal_idx = np.random.randint(len(self.goals))
    self.goal = self.generate_goal()
    # Invalidate the cached goal image.
    self.rendered_goal = False
    return self._get_obs()

class Kitchen3DGoalEnv(GymGoalEnvWrapper):
    """Goal-env wrapper around ``Kitchen3DIntermediateEnv``.

    Observations come from the inner env's ``observation`` key. Goals come
    from its ``achieved_goal`` / ``state_achieved_goal`` keys.
    """

    def __init__(self, fixed_start=True, fixed_goal=False, images=False, image_kwargs=None):
        # NOTE(review): fixed_start, fixed_goal, images and image_kwargs are
        # accepted for interface compatibility but are currently ignored.
        env = Kitchen3DIntermediateEnv()
        super(Kitchen3DGoalEnv, self).__init__(
            env, observation_key='observation', goal_key='achieved_goal', state_goal_key='state_achieved_goal'
        )

    def compute_shaped_distance(self, achieved_state, goal):
        """Return the Euclidean norm of ``achieved_state - goal``.

        Inputs may be numpy arrays or array-likes. The norm is taken over the
        whole difference array (Frobenius norm for 2-D input), not per row.
        """
        return np.linalg.norm(np.asarray(achieved_state) - np.asarray(goal))

    def get_shaped_distance(self, states, goal_states):
        """Alias for ``compute_shaped_distance``."""
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Render the underlying base env as an RGB array."""
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """Return logging statistics for a batch of trajectories.

        No diagnostics are implemented for this environment yet, so an empty
        ``OrderedDict`` is always returned.

        Args:
            trajectories: Numpy Array [# Trajectories x Max Path Length x State Dim]
            desired_goal_states: Numpy Array [# Trajectories x State Dim]

        Returns:
            An empty ``OrderedDict``.
        """
        return OrderedDict()