from lexa_benchmark.envs.kitchen import KitchenEnv
from collections import OrderedDict
import numpy as np
from gym.spaces import Box, Dict
import mujoco_py

from multiworld.core.serializable import Serializable
from multiworld.envs.env_util import (
    get_stat_in_paths,
    create_stats_ordered_dict,
    get_asset_full_path,
)

from multiworld.envs.mujoco.mujoco_env import MujocoEnv
import copy

from multiworld.core.multitask_env import MultitaskEnv
import matplotlib.pyplot as plt
import os.path as osp
from huge.envs.gymenv_wrapper import GymGoalEnvWrapper
import numpy as np
import gym
import random
import itertools
from itertools import combinations
from envs.base_envs import BenchEnv
from d4rl.kitchen.kitchen_envs import KitchenMicrowaveKettleLightTopLeftBurnerV0
from gym import spaces
import torch

# Target joint value for each manipulable object. The value is written into
# the goal vector at the matching OBJECT_GOAL_IDXS position.
OBJECT_GOAL_VALS = {
    'slide_cabinet': [0.37],
    'hinge_cabinet': [0.25],  # alternative value previously tried: [1.45]
    'microwave': [-0.75],
}

# 3-D position compared against the last three observation entries while the
# object has not yet moved (used by the shaped distance).
OBJECT_KEY_POS = {
    'slide_cabinet': [-0.12, 0.65, 2.6],
    'hinge_cabinet': [-0.53, 0.65, 2.6],
    'microwave': [-0.63, 0.48, 1.8],
}

# 3-D position appended to the goal vector after the joint values.
FINAL_KEY_POS = {
    'slide_cabinet': [0.2, 0.65, 2.6],
    'hinge_cabinet': [-0.45, 0.53, 2.6],
    'microwave': [-0.7, 0.38, 1.8],
}

# Index of each object's joint inside the 5-entry extracted state.
OBJECT_GOAL_IDXS = {
    'slide_cabinet': [2],
    'hinge_cabinet': [3],
    'microwave': [4],
}

# Reference values of the 5-entry extracted state before any manipulation.
INITIAL_STATE = np.array([
    4.79267505e-02,
    3.71350919e-02,
    -4.65501369e-04,
    -6.44129196e-03,
    -1.77048263e-03,
])

BASE_TASK_NAMES = [
    'bottom_burner',
    'light_switch',
    'slide_cabinet',
    'hinge_cabinet',
    'microwave',
    # 'kettle'
]


class KitchenIntermediateEnv(BenchEnv):
    """Single-object, goal-conditioned wrapper around the D4RL kitchen env.

    Observations and goals share one flat layout: the five entries produced
    by ``internal_extract_obs`` followed by the pose returned by the wrapped
    environment's ``get_ee_pose()`` (the declared spaces reserve three
    entries for it).
    """

    def __init__(self, goal_name='slide_cabinet', action_repeat=1, use_goal_idx=False, log_per_goal=False, control_mode='end_effector', width=64):
        """Build the wrapped kitchen environment and perform an initial reset.

        Args:
            goal_name: Key into ``OBJECT_GOAL_VALS`` / ``OBJECT_GOAL_IDXS`` /
                ``FINAL_KEY_POS`` selecting the object to manipulate.
            action_repeat: Forwarded to ``BenchEnv.__init__``; ``step`` reads
                it back as ``self._action_repeat`` (presumably set by the
                base class -- confirm in ``BenchEnv``).
            use_goal_idx: Stored on the instance; not read in this class.
            log_per_goal: Stored on the instance; not read in this class.
            control_mode: Forwarded to the wrapped kitchen environment.
            width: Used as both image width and height of the wrapped env.
        """
        super().__init__(action_repeat, width)
        self.use_goal_idx = use_goal_idx
        self.log_per_goal = log_per_goal

        with self.LOCK:
            self._env = KitchenMicrowaveKettleLightTopLeftBurnerV0(
                frame_skip=16, control_mode=control_mode, imwidth=width, imheight=width)

            self._env.sim_robot.renderer._camera_settings = dict(
                distance=3, lookat=[-0.3, .5, 2.], azimuth=90, elevation=-60)

            # Five extracted state entries bounded by +-8, three pose
            # entries bounded by +-4.
            obs_upper = 8.0 * np.ones(5)
            obs_lower = -obs_upper
            obs_upper_pose = 4 * np.ones(3)
            obs_lower_pose = -obs_upper_pose
            low = np.concatenate([obs_lower, obs_lower_pose])
            high = np.concatenate([obs_upper, obs_upper_pose])
            self._observation_space = spaces.Box(low, high, dtype=np.float32)
            self._goal_space = spaces.Box(low, high, dtype=np.float32)
            print("observation space in kitchen", self._observation_space)

        # goal_name must be set before reset(), which calls generate_goal().
        self.goal_name = goal_name
        initial_obs = self.reset()

        print("initial obs", initial_obs)
        print("goal_name ", goal_name)

    def generate_goal(self):
        """Return the goal vector for ``self.goal_name``.

        The goal is a copy of ``INITIAL_STATE`` with the selected object's
        joint entry replaced by its target value, followed by the object's
        ``FINAL_KEY_POS`` entry.

        Raises:
            KeyError: If ``self.goal_name`` is not present in the goal tables.
        """
        # Copy so the module-level reference state is never mutated.
        goal_state = INITIAL_STATE.copy()
        goal_state[OBJECT_GOAL_IDXS[self.goal_name]] = OBJECT_GOAL_VALS[self.goal_name]
        return np.concatenate([goal_state, FINAL_KEY_POS[self.goal_name]])

    def internal_extract_obs(self, obs):
        """Select the entries of the raw observation used by this wrapper.

        Returns a 5-entry array: ``obs[7:9]`` (named gripper position here),
        then ``obs[19]``, ``obs[21]`` and ``obs[22]`` (named slide cabinet,
        hinge cabinet and microwave joints here).
        """
        gripper_pos = obs[7:9]
        slide_cabinet_joint = [obs[19]]
        hinge_cabinet_joint = [obs[21]]
        microwave_joint = [obs[22]]
        return np.concatenate([gripper_pos, slide_cabinet_joint, hinge_cabinet_joint, microwave_joint])

    def render_image(self):
        """Render the wrapped environment in ``rgb_array`` mode."""
        return self._env.render(mode="rgb_array")

    def render(self):
        """Render the wrapped environment in ``human`` mode."""
        return self._env.render(mode="human")

    @property
    def state_space(self):
        """Box space of the flat state vector."""
        return self._observation_space

    @property
    def goal_space(self):
        """Box space of the flat goal vector."""
        return self._goal_space

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        return self._env.action_space

    @property
    def observation_space(self):
        """Dict space describing the dictionary returned by ``_get_obs``."""
        return Dict([
            ('observation', self.state_space),
            ('desired_goal', self.goal_space),
            ('achieved_goal', self.goal_space),
            ('state_observation', self.state_space),
            ('state_desired_goal', self.state_space),
            ('state_achieved_goal', self.state_space),
        ])

    def _get_obs(self):
        """Assemble the goal-conditioned observation dictionary.

        The observation keys all refer to the same array object, as do the
        two desired-goal keys (which reference ``self.goal``).
        """
        world_obs = self.internal_extract_obs(self._env._get_obs())
        ee_obs = self._env.get_ee_pose()
        obs = np.concatenate([world_obs, ee_obs])
        goal = self.goal

        return dict(
            observation=obs,
            desired_goal=goal,
            achieved_goal=obs,
            state_observation=obs,
            state_desired_goal=goal,
            state_achieved_goal=obs,
        )

    def step(self, action):
        """Apply ``action`` up to ``self._action_repeat`` times.

        The wrapped environment's reward is discarded; the returned reward is
        always ``0.0``.

        Returns:
            Tuple ``(obs, reward, done, info)`` where ``obs`` comes from
            ``_get_obs`` and ``done`` / ``info`` come from the last inner step.
        """
        total_reward = 0.0
        # Defaults keep the return well-defined when the repeat count is < 1
        # and the loop body never executes.
        done = False
        info = {}
        for _ in range(self._action_repeat):
            _, _, done, info = self._env.step(action)
            if done:
                break
        obs = self._get_obs()
        # Forward any metric entries found in the observation dict.
        for key, value in obs.items():
            if 'metric_' in key:
                info[key] = value
        return obs, total_reward, done, info

    def reset(self):
        """Reset the wrapped environment, regenerate the goal, return the observation dict."""
        with self.LOCK:
            self._env.reset()
        self.goal = self.generate_goal()
        return self._get_obs()

class KitchenGoalEnv(GymGoalEnvWrapper):
    """Goal-env wrapper exposing success and shaped-distance metrics.

    All metrics are evaluated for the object named by
    ``self.base_env.goal_name``.
    """

    # Largest absolute joint error that still counts as success, per object.
    _SUCCESS_TOLERANCE = {
        'slide_cabinet': 0.1,
        'hinge_cabinet': 0.2,
        'microwave': 0.2,
    }

    def __init__(self, goal_name='slide_cabinet', fixed_start=True, fixed_goal=False, images=False, image_kwargs=None):
        """Create the underlying ``KitchenIntermediateEnv`` and wrap it.

        Args:
            goal_name: Object to manipulate; forwarded to the wrapped env.
            fixed_start: Accepted for interface compatibility; unused here.
            fixed_goal: Accepted for interface compatibility; unused here.
            images: Accepted for interface compatibility; unused here.
            image_kwargs: Accepted for interface compatibility; unused here.
        """
        env = KitchenIntermediateEnv(goal_name)
        super().__init__(
            env, observation_key='observation', goal_key='achieved_goal', state_goal_key='state_achieved_goal'
        )

    def _goal_joint_error(self, achieved_state):
        """Absolute error between the active object's joint and its target value."""
        goal_name = self.base_env.goal_name
        # Scalar index: the tables store one index per object.
        joint_idx = OBJECT_GOAL_IDXS[goal_name][0]
        return abs(achieved_state[joint_idx] - OBJECT_GOAL_VALS[goal_name])

    def compute_success(self, achieved_state, goal):
        """Return whether the active object's joint is within tolerance of its target.

        ``goal`` is not consulted; the target comes from ``OBJECT_GOAL_VALS``.

        Raises:
            KeyError: If the active goal name has no entry in the tables.
        """
        tolerance = self._SUCCESS_TOLERANCE[self.base_env.goal_name]
        return self._goal_joint_error(achieved_state) < tolerance

    def success_distance(self, achieved_state, goal):
        """Return the absolute joint error of the active object.

        ``goal`` is not consulted; the target comes from ``OBJECT_GOAL_VALS``.
        """
        return self._goal_joint_error(achieved_state)

    def compute_shaped_distance(self, achieved_state, goal):
        """Return a two-phase shaped distance to the goal.

        While the object's joint is within 0.03 of ``INITIAL_STATE`` and the
        last three state entries are more than 0.05 away from the object's
        ``OBJECT_KEY_POS``, the result is the distance to that key position,
        plus the distance of the first two state entries from ``[1, 1]``,
        plus the joint error, plus a constant 2. Otherwise only the joint
        error between ``achieved_state`` and ``goal`` is returned.
        """
        goal_name = self.base_env.goal_name
        goal_idxs = OBJECT_GOAL_IDXS[goal_name]

        achieved_joint = achieved_state[goal_idxs]
        joint_error = abs(achieved_joint - goal[goal_idxs])
        moved_from_initial = abs(INITIAL_STATE[goal_idxs] - achieved_joint)
        distance_to_key_pos = np.linalg.norm(achieved_state[-3:] - OBJECT_KEY_POS[goal_name])

        if moved_from_initial < 0.03 and distance_to_key_pos > 0.05:
            gripper_open = np.linalg.norm(achieved_state[:2] - np.array([1, 1]))
            # NOTE(review): the +2 offset presumably keeps this phase's value
            # above the joint-only phase -- confirm the intended scale.
            return distance_to_key_pos + gripper_open + joint_error + 2
        return joint_error

    def get_shaped_distance(self, states, goal_states):
        """Delegate to ``compute_shaped_distance`` with the arguments unchanged."""
        return self.compute_shaped_distance(states, goal_states)

    def render_image(self):
        """Return the wrapped environment's rendered image."""
        return self.base_env.render_image()

    def get_diagnostics(self, trajectories, desired_goal_states):
        """Return an empty ``OrderedDict``; no diagnostics are computed."""
        return OrderedDict()