import os
import time

import gym
import numpy as np
from gym.spaces import Dict

from locobot_env.resources.Locobot_interface import LoCoBotInterface

class LocobotEnv(gym.Env):
    """Goal-conditioned gym environment for a LoCoBot pick-and-place task.

    The state is the 4-vector ``[x1, y1, x2, y2]`` built from the positions
    returned by ``LoCoBotInterface.get_blue_average_position`` and
    ``get_green_average_position`` on the latest RGB camera frame.
    Actions are ``(x, y)`` target points; on even timesteps the robot grabs
    at the target, on odd timesteps it releases. Goal states/images are
    loaded from ``.npy`` files.
    """

    # Default directory holding the goal ``*_state.npy`` / ``*_image.npy`` files.
    DEFAULT_GOAL_DIR = (
        "/home/locobot/Desktop/new_interface/"
        "goalrelabel_locobot_fullgcsl/gcsl/envs/locobot/goals/"
    )

    # Per-object term of the shaped distance (see ``get_shaped_distance``).
    STAGE_BONUS = 5

    def __init__(self, goal_dir=None, verbose=False):
        """Build the spaces, connect to the robot, load goals and read the first state.

        Args:
            goal_dir: directory containing the goal ``.npy`` files; defaults to
                ``DEFAULT_GOAL_DIR``.
            verbose: if True, ``compute_success`` prints diagnostic output.
        """
        self.verbose = verbose

        # (x, y) target point for the arm; units presumably metres — TODO confirm.
        self.action_space = gym.spaces.box.Box(
            low=np.array([0.14, -0.15], dtype=np.float32),
            high=np.array([0.35, 0.15], dtype=np.float32))

        # Two (x, y) object positions, each coordinate bounded to [0, 1].
        self.obs_space = gym.spaces.box.Box(
            low=np.array([0, 0, 0, 0], dtype=np.float32),
            high=np.array([1, 1, 1, 1], dtype=np.float32))
        self.state_space = self.obs_space

        self.observation_space = Dict([
            ('observation', self.obs_space),
            ('desired_goal', self.obs_space),
            ('achieved_goal', self.obs_space),
            ('state_observation', self.obs_space),
            ('state_desired_goal', self.obs_space),
            ('state_achieved_goal', self.obs_space),
        ])

        self.locobot = LoCoBotInterface()
        self.goal_size = 4
        # Max distance between an object and its goal position to count as matched.
        self.threshold = 0.1
        self.timestep = 0

        self.init_goals(goal_dir)

        self.goal_state = None
        self.goal_image = None
        self.sample_goal()

        # Fetching ROS images is slow; discard one frame and wait before
        # reading the initial state (presumably lets the image stream settle —
        # TODO confirm the 2 s figure is still needed).
        self.locobot.get_image_rgb()
        time.sleep(2)
        self.current_image = None
        self.current_state = None
        self.update_state()

    def init_goals(self, base_path=None):
        """Load goal states and goal images into parallel lists.

        Populates ``self.goal_states`` and ``self.goal_images``. Currently only
        the ``goal_two_socks`` goal is loaded.

        Args:
            base_path: directory of the goal files; defaults to
                ``DEFAULT_GOAL_DIR``.
        """
        if base_path is None:
            base_path = self.DEFAULT_GOAL_DIR

        self.goal_states = []
        self.goal_images = []

        image = np.load(os.path.join(base_path, "goal_two_socks_image.npy"))
        state = np.load(os.path.join(base_path, "goal_two_socks_state.npy"))

        self.goal_states.append(state)
        self.goal_images.append(image)

    def step(self, action):
        """Move to ``action``, then grab (even timestep) or release (odd timestep).

        Args:
            action: ``(x, y)`` target point; each coordinate is clipped to the
                bounds of ``action_space``.

        Returns:
            ``(obs, reward, done, info)`` where reward is always 0, done is
            always False and info is an empty dict.
        """
        x, y = action
        low, high = self.action_space.low, self.action_space.high
        x = np.clip(x, low[0], high[0])
        y = np.clip(y, low[1], high[1])

        self.locobot.move_to_point(x, y)
        if self.timestep % 2 == 0:
            self.locobot.grab_object()
        else:
            self.locobot.leave_object()

        self.update_state()
        self.timestep += 1

        reward = 0
        done = False
        return self._get_obs(), reward, done, {}

    def go_rest(self):
        """Send the robot to its rest pose."""
        self.locobot.go_rest()

    def reset(self):
        """Reset the timestep counter and the robot, then return a fresh observation."""
        self.timestep = 0
        self.locobot.reset()

        self.update_state()

        return self._get_obs()

    def render(self, mode='human'):
        """Return the latest camera frame; ``mode`` is accepted for gym compatibility and ignored."""
        return self.current_image

    def render_image(self):
        """Return the latest camera frame."""
        return self.current_image

    def update_state(self):
        """Grab a camera frame and recompute ``current_state`` from it.

        Any coordinate the interface reports as ``None`` (object not detected)
        is replaced by the corresponding value from the previous state.

        Raises:
            RuntimeError: if a coordinate is missing and there is no previous
                state to fall back on (e.g. on the very first frame).
        """
        previous_state = self.current_state
        self.current_image = self.locobot.get_image_rgb()

        x1, y1 = self.locobot.get_blue_average_position(self.current_image)
        x2, y2 = self.locobot.get_green_average_position(self.current_image)
        values = [x1, y1, x2, y2]

        for i, value in enumerate(values):
            if value is None:
                if previous_state is None:
                    raise RuntimeError(
                        "object position %d not detected and no previous state "
                        "is available to fall back on" % i)
                values[i] = previous_state[i]

        # Always store a float array (never object dtype) so downstream
        # numeric code and tensor conversion work.
        self.current_state = np.asarray(values, dtype=float)

    def seed(self, seed=None):
        """Seed ``self.np_random`` and return ``[seed]`` per the gym API."""
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def interact(self):
        """No-op hook."""
        return

    def observation(self, state):
        """Return ``state`` unchanged (observations are the states themselves)."""
        return state

    def extract_goal(self, state):
        """Return ``state`` unchanged (goals live in the same space as states)."""
        return state

    def _get_obs(self):
        """Build the goal-conditioned observation dict.

        Every entry is an independent copy, so callers may mutate the arrays
        without affecting ``current_state`` or the stored goals.
        """
        state = self.current_state
        goal = self.goal_state

        return dict(
            observation=state.copy(),
            desired_goal=goal.copy(),
            achieved_goal=state.copy(),
            state_observation=state.copy(),
            state_desired_goal=goal.copy(),
            state_achieved_goal=state.copy(),
        )

    def get_shaped_distance(self, state1, state2):
        """Shaped distance between two states.

        Computed as the Euclidean distance over the first ``goal_size``
        entries plus ``STAGE_BONUS`` per (x, y) pair, minus ``STAGE_BONUS``
        for each leading pair (in order) within ``threshold`` of its
        counterpart. The discount stops at the first pair that is too far.
        Consequently, objects are credited sequentially.
        """
        s1 = np.asarray(state1)
        s2 = np.asarray(state2)
        n = self.goal_size

        distance = np.linalg.norm(s1[:n] - s2[:n]) + self.STAGE_BONUS * (n // 2)
        for i in range(0, n, 2):
            if np.linalg.norm(s1[i:i + 2] - s2[i:i + 2]) > self.threshold:
                break
            distance -= self.STAGE_BONUS

        return distance

    def compute_shaped_distance(self, state1, state2):
        """Alias for ``get_shaped_distance``."""
        return self.get_shaped_distance(state1, state2)

    def sample_goal(self, index=0):
        """Make goal ``index`` the current goal and return its state.

        Sampling is currently fixed to a single goal (index 0 by default);
        random selection over ``self.goal_states`` is not enabled.
        """
        self.goal_state = self.goal_states[index]
        self.goal_image = self.goal_images[index]
        return self.goal_state

    def plot_trajectories(self):
        """No-op hook."""
        return

    def compute_success(self, state, goal):
        """Count the (x, y) pairs of ``state`` within ``threshold`` of ``goal``.

        Only the first ``goal_size`` entries are compared.

        Returns:
            The number of matched pairs (an int).
        """
        state = np.asarray(state)[:self.goal_size]
        goal = np.asarray(goal)[:self.goal_size]

        # NOTE(review): a pair at exactly ``threshold`` is not counted here but
        # is treated as matched by ``get_shaped_distance`` (which uses ``>``).
        num_success = sum(
            1 for i in range(0, len(state), 2)
            if np.linalg.norm(state[i:i + 2] - goal[i:i + 2]) < self.threshold
        )

        if self.verbose:
            print("compute success", num_success, state, goal, self.threshold)

        return num_success