import numpy as np
from gym.envs.mujoco.swimmer import SwimmerEnv
from gym import register
from safe.envs.safe_env_spec import SafeEnv, interval_barrier
import torch


class SafeSwimmerEnv(SwimmerEnv, SafeEnv):
    """Swimmer environment with a safety constraint on the first observation entry.

    A state is considered safe when ``abs(state[..., 0]) <= x_threshold``
    (see :meth:`is_state_safe`). :meth:`step` reports violations through
    ``info['episode.unsafe']`` without terminating the episode or changing the
    reward.

    Args:
        x_threshold: Half-width of the safe interval for observation index 0.
        y_threshold: Stored on the instance; not used by any method in this class.
    """

    def __init__(self, x_threshold=1, y_threshold=1):
        # The thresholds are assigned before the parent constructor runs.
        # NOTE(review): presumably because the parent constructor calls step()
        # during setup, which reads self.x_threshold - confirm against gym.
        self.x_threshold = x_threshold
        self.y_threshold = y_threshold
        super().__init__()

    def step(self, a):
        """Advance the simulation and tag the transition as safe or unsafe.

        Args:
            a: Action forwarded unchanged to the parent ``step``.

        Returns:
            The parent's ``(ob, reward, done, info)`` tuple, with
            ``info['episode.unsafe']`` set to ``True`` when the new observation
            violates the safety constraint and ``False`` otherwise.
        """
        ob, reward, done, info = super().step(a)
        # Use the same symmetric criterion as is_state_safe / barrier_fn so the
        # logged flag agrees with the safety definition used for learning.
        # The episode is intentionally neither terminated nor penalised here.
        info['episode.unsafe'] = bool(abs(ob[0]) > self.x_threshold)
        return ob, reward, done, info

    def _get_obs(self):
        """Return all of ``qpos`` followed by all of ``qvel`` as a float32 vector.

        Index 0 of the result is the first ``qpos`` entry, which is the
        quantity the safety checks in this class compare to ``x_threshold``.
        """
        qpos = self.sim.data.qpos
        qvel = self.sim.data.qvel
        return np.concatenate([qpos.flat, qvel.flat]).astype(np.float32)

    def reset_model(self):
        """Reset to exactly ``init_qpos`` / ``init_qvel`` and return the observation."""
        self.set_state(self.init_qpos, self.init_qvel)
        return self._get_obs()

    def is_state_safe(self, states: torch.Tensor):
        """Return an element-wise mask of ``abs(states[..., 0]) <= x_threshold``."""
        return abs(states[..., 0]) <= self.x_threshold

    def barrier_fn(self, states: torch.Tensor):
        """Evaluate ``interval_barrier`` on ``states[..., 0]`` over ``[-x_threshold, x_threshold]``."""
        return interval_barrier(states[..., 0], -self.x_threshold, self.x_threshold)

    def reward_fn(self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor):
        """Compute the reward of a transition from tensors.

        The reward is the change in index 0 between ``states`` and
        ``next_states`` divided by ``self.dt``, minus ``1e-4`` times the sum of
        squared actions over the last dimension.
        """
        reward_fwd = (next_states[..., 0] - states[..., 0]) / self.dt
        reward_ctrl = -1e-4 * actions.pow(2).sum(dim=-1)
        reward = reward_fwd + reward_ctrl
        return reward


# Register the environment under the id 'SafeSwimmer-v2' with a 1000-step
# episode limit.
register(
    'SafeSwimmer-v2',
    entry_point=SafeSwimmerEnv,
    max_episode_steps=1000,
)


@torch.no_grad()
def farthest_safe_state(f, s0, lower=-2, num_iters=100):
    """Bisect the first coordinate of ``s0`` for the point where ``f`` crosses 1.

    The search interval is ``[lower, s0[0]]``. At each step the midpoint is
    written into coordinate 0 of a copy of ``s0``; if ``f`` of that state is
    ``>= 1`` the lower end moves to the midpoint, otherwise the upper end does.
    The search is only meaningful when ``f >= 1`` near ``lower`` and ``f < 1``
    near ``s0[0]``.
    NOTE(review): presumably ``f`` is a barrier-style function where values
    ``>= 1`` mean "unsafe" - confirm with the caller.

    Args:
        f: Callable taking a state tensor and returning a value comparable
            with ``>= 1`` (e.g. a one-element tensor).
        s0: 1-D state tensor; it is cloned and never modified.
        lower: Lower end of the search interval for coordinate 0.
        num_iters: Number of bisection steps.

    Returns:
        The last midpoint evaluated, as a Python float. If ``num_iters`` is 0
        the original ``s0[0]`` is returned.
    """
    probe = s0.clone()  # work on a copy so the caller's tensor stays intact
    lo, hi = lower, probe[0].item()
    mid = hi  # defined even when no bisection step is requested
    for _ in range(num_iters):
        mid = (lo + hi) / 2
        probe[0] = mid
        if f(probe) >= 1:
            lo = mid
        else:
            hi = mid
    return mid
