import safety_gym
import gym
from safety_gym.envs.engine import Engine
import torch
import torch.nn as nn
import numpy as np

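# Customizations applied to the stock Safety Gym environment in this module:
# 1. Disable layout randomization (randomize_layout = False).
# 2. Change the observation: append the robot's xy position and heading angle.
# 3. Define the safety predicate and barrier function (SafeEnv interface).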

from safe.envs.safe_env_spec import SafeEnv


def from_mat3_to_angle(mat):
    """Return the angle encoded in the first row of *mat* as ``float32``.

    Only ``mat[0, 0]`` and ``mat[0, 1]`` are read; they are passed to
    ``np.arctan2`` in that order (``mat[0, 0]`` as the "y" argument,
    ``mat[0, 1]`` as the "x" argument).  This is the inverse of the layout
    produced by ``from_angle_to_mat3``.
    """
    sin_like = mat[0, 0]
    cos_like = mat[0, 1]
    angle = np.arctan2(sin_like, cos_like)
    return angle.astype(np.float32)


def from_angle_to_mat3(angle):
    """Build the ``float32`` matrix ``[[sin, cos], [-cos, sin]]`` for *angle*.

    Despite the name, the result has two rows and two columns.  Its first row
    is ``[sin(angle), cos(angle)]``, which is the layout that
    ``from_mat3_to_angle`` reads back.
    """
    c = np.cos(angle)
    s = np.sin(angle)
    top_row = [s, c]
    bottom_row = [-c, s]
    return np.array([top_row, bottom_row], dtype=np.float32)


class MyEngine(Engine, SafeEnv):
    """``Engine`` subclass with an extended observation and ``SafeEnv`` hooks.

    Differences from the parent ``Engine`` that are implemented here:

    * every observation gets three extra trailing entries (robot x, robot y
      and a heading angle), and ``observation_space`` is widened to match;
    * ``step`` terminates the episode and subtracts 10 from the reward as
      soon as the parent reports a positive ``info['cost']``;
    * ``is_state_safe``, ``barrier_fn`` and ``reward_fn`` work on batched
      torch observations.
    """

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, config):
        """Build the parent engine, then widen the observation space by 3."""
        super().__init__(config)
        # Both spaces are kept: ``obs`` temporarily restores the original one
        # while the parent builds its observation.
        self.old_observation_space = self.observation_space
        self.new_observation_space = gym.spaces.Box(
            -np.inf, np.inf, [self.old_observation_space.shape[0] + 3])
        self.observation_space = self.new_observation_space

    def extra_obs(self):
        """Return ``[x, y, theta]`` for the robot as a 1-D array."""
        robot_pos = self.world.robot_pos()[:2]
        robot_mat = self.world.robot_mat()
        mat_theta = from_mat3_to_angle(robot_mat)
        return np.r_[robot_pos, mat_theta]

    def obs(self):
        """Return the parent observation with ``extra_obs()`` appended."""
        # Engine checks the validity of the observation against
        # ``self.observation_space``, so expose the original space while the
        # parent runs and always restore the widened one afterwards, even if
        # the parent raises.
        self.observation_space = self.old_observation_space
        try:
            base_obs = super().obs()
        finally:
            self.observation_space = self.new_observation_space
        return np.r_[base_obs, self.extra_obs()]   # add my modification

    def step(self, action):
        """Step the parent env; any positive cost ends the episode.

        Sets ``info['episode.unsafe']`` to True when ``info['cost'] > 0`` (and
        then also subtracts 10 from the reward and forces ``done``), otherwise
        to False.
        """
        next_obs, reward, done, info = super().step(action)
        if info['cost'] > 0:
            info['episode.unsafe'] = True
            reward -= 10
            done = True
        else:
            info['episode.unsafe'] = False
        return next_obs, reward, done, info

    def _dist_to_nearest_hazard(self, states: torch.Tensor):
        """Estimate the distance to the closest hazard from observations.

        Takes the maximum over entries ``22:38`` of the last dimension and
        maps it to a distance as ``(1 - max) * lidar_max_dist``.
        """
        # NOTE(review): assumes entries 22:38 of the flat observation are the
        # hazard lidar bins, with readings that grow towards 1 as a hazard
        # gets closer -- confirm against the environment's observation layout.
        strongest_reading = states[..., 22:38].max(dim=-1).values
        return (1 - strongest_reading) * self.lidar_max_dist

    def is_state_safe(self, states: torch.Tensor):
        """Return a boolean tensor that is True where a state is safe.

        A state counts as safe when the estimated distance to the nearest
        hazard is strictly greater than ``hazards_size``.  (The comparison
        used to be ``<=``, which reported the states inside a hazard as safe,
        contradicting both ``barrier_fn`` and the unsafe flag set in ``step``.)
        """
        return self._dist_to_nearest_hazard(states) > self.hazards_size

    def barrier_fn(self, states: torch.Tensor):
        """Return a non-negative barrier value per state.

        The value is ``(hazards_size - dist) * 100 + 1`` clamped at 0, so it
        equals 1 exactly at ``dist == hazards_size`` and grows as the
        estimated distance to the nearest hazard shrinks.
        """
        dist_to_min_hazards = self._dist_to_nearest_hazard(states)
        return ((self.hazards_size - dist_to_min_hazards) * 100 + 1).clamp(min=0)

    def reward_fn(self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor):
        """Return a zero reward for each of the ``len(states)`` entries."""
        n = len(states)
        return torch.zeros(n, device=states.device)


class SlimEngine(MyEngine):
    """Subclass of ``MyEngine`` that adds no behaviour of its own."""


from gym.envs.registration import registry
from copy import deepcopy

# Register a customized copy of an existing environment at import time.
# The keyword arguments of the already-registered spec are deep-copied so the
# modification below does not touch the original registry entry.
env_id = 'Safexp-PointGoal1-v0'
config = deepcopy(registry.spec(env_id)._kwargs)
# Turn off the `randomize_layout` option in the copied engine config.
config['config']['randomize_layout'] = False

# New id, backed by MyEngine, with episodes capped at 1000 steps.
gym.register('MySafexp-PointGoal1-v0', entry_point=MyEngine, kwargs=config, max_episode_steps=1000)
