import copy
import warnings

from typing import Union, List

import gym
import numpy as np
from pypoman import compute_polytope_halfspaces
from action_masking.sb3_contrib.common.utils import fetch_fn
from action_masking.util.util import ContMaskingMode
from action_masking.util.sets import Zonotope, normalize_zonotope

# Constants needed for calculating U_eq (center of the action space).
# They are only combined as ``GRAVITY / K`` to build the fallback center
# ``[GRAVITY / K, 0]`` in ``ActionMaskingWrapper.next_center_obs`` when the
# safe space is neither a numpy array nor a Zonotope.
# NOTE(review): GRAVITY is presumably the gravitational acceleration in m/s^2
# and K an environment-specific gain; the origin of 0.89 and 1.4 is not
# documented in this file -- confirm against the environment definition.
GRAVITY = 9.81
K = 0.89/1.4

class ActionMaskingWrapper(gym.Wrapper):
    """Gym wrapper that restricts the executed actions to a safe action set.

    For ``Discrete`` action spaces the wrapper appends one auxiliary action and
    maintains a boolean mask over the actions (see :meth:`action_masks`). If no
    action is safe, only the auxiliary action is enabled and the verified
    fail-safe controller is executed instead of the policy action.

    For ``Box`` action spaces the policy action is mapped into the safe action
    space that was stored by the last call to :meth:`get_safe_space`. If no safe
    space is stored (``self.safe_space is None``), the fail-safe controller is
    executed.

    :param env: Gym environment
    :param safe_region: Safe region instance
    :param dynamics_fn: Dynamics function
    :param safe_control_fn: Verified fail safe action function
    :param safe_region_polytope: Safe region passed to
        ``continuous_action_space_fn_polytope`` (optional)
    :param punishment_fn: Reward punishment function
    :param alter_action_space: Alternative gym action space
    :param generate_wrapper_tuple: Generate tuple (wrapper action, environment reward)
    :param transform_action_space_fn: Action space transformation function
    :param continuous_safe_space_fn: Safe (continuous) action space function
    :param continuous_action_space_fn_polytope: Optional second safe action space
        function; if given, its result replaces the one of
        ``continuous_safe_space_fn`` in :meth:`get_safe_space`
    :param inv_transform_action_space_fn: Inverse action space transformation function
    :param continuous_action_masking_mode: ``ContMaskingMode`` member selecting the
        continuous masking approach; ``None`` is treated like ``Interval``
    :param generator_dim: Dimension of the generator action space; required for
        ``ContMaskingMode.Generator``
    :param safe_center_obs: If True (continuous action spaces only), the center of
        the next safe action space is appended to every observation
    """

    def __init__(self,
                 env: gym.Env,
                 safe_region,
                 dynamics_fn,
                 safe_control_fn,
                 safe_region_polytope=None,
                 punishment_fn=None,
                 alter_action_space=None,
                 generate_wrapper_tuple=False,
                 transform_action_space_fn=None,
                 continuous_safe_space_fn=None,
                 continuous_action_space_fn_polytope=None,
                 inv_transform_action_space_fn=None,
                 continuous_action_masking_mode=None,
                 generator_dim=None,
                 safe_center_obs=False):

        super().__init__(env)

        self._mask = None
        self._safe_region = safe_region
        self._safe_region_polytope = safe_region_polytope
        self._dynamics_fn = fetch_fn(self.env, dynamics_fn)
        self._generate_wrapper_tuple = generate_wrapper_tuple
        self._punishment_fn = fetch_fn(self.env, punishment_fn)
        self._safe_control_fn = fetch_fn(self.env, safe_control_fn)
        self._continuous_safe_space_fn = fetch_fn(self.env, continuous_safe_space_fn)
        self._continuous_safe_space_fn_polytope = (
            fetch_fn(self.env, continuous_action_space_fn_polytope)
            if continuous_action_space_fn_polytope is not None else None
        )
        self._transform_action_space_fn = fetch_fn(self.env, transform_action_space_fn)
        # Always define the inverse transformation. It is read unconditionally in
        # ``step``, ``failsafe_action`` and ``next_center_obs``; leaving it unset
        # (as happened whenever ``alter_action_space`` was None or a warning
        # branch below was taken) resulted in an AttributeError.
        self._inv_transform_action_space_fn = (
            fetch_fn(self.env, inv_transform_action_space_fn)
            if inv_transform_action_space_fn is not None else None
        )
        self.safe_center_obs = safe_center_obs
        self.previous_fail_safe = False

        if not hasattr(self.env, "action_space"):
            warnings.warn("Environment has no attribute ``action_space``")

        if alter_action_space is not None:
            self.action_space = alter_action_space
            if transform_action_space_fn is None and isinstance(self.action_space, gym.spaces.Discrete):
                warnings.warn("Set ``alter_action_space`` but no ``transform_action_space_fn``")
            elif generate_wrapper_tuple and inv_transform_action_space_fn is None:
                warnings.warn("``generate_wrapper_tuple`` but no ``inv_transform_action_space_fn``")

        if not isinstance(self.action_space, (gym.spaces.Discrete, gym.spaces.Box)):
            raise ValueError(f"{type(self.action_space)} not supported")

        if isinstance(self.action_space, gym.spaces.Discrete):
            # There is no safe-space center for discrete actions.
            self.safe_center_obs = False
            # Extend action space with auxiliary action
            self._num_actions = self.action_space.n + 1
            self.action_space = gym.spaces.Discrete(self._num_actions)
        elif self._continuous_safe_space_fn is None:
            raise ValueError(f"{type(self.action_space)} but no ``continuous_safe_space_fn``")

        if self.safe_center_obs:
            # Observation = original observation + center of the safe action space.
            self.observation_space = gym.spaces.Box(
                low=np.concatenate((self.observation_space.low, self.action_space.low)),
                high=np.concatenate((self.observation_space.high, self.action_space.high))
            )
        self._continuous_action_masking_mode = continuous_action_masking_mode

        if self._continuous_action_masking_mode == ContMaskingMode.Generator:
            if generator_dim is None:
                raise ValueError('Generator dimensions must be set for using this mode')
            # The policy outputs one factor in [-1, 1] per generator.
            self.action_space = gym.spaces.Box(- np.ones(generator_dim), np.ones(generator_dim))

        # Safe space (physical coordinates) stored by ``get_safe_space``.
        self.safe_space = None

    @property
    def continuous_action_masking_mode(self):
        """Return the configured continuous action masking mode."""
        return self._continuous_action_masking_mode

    def reset(self, **kwargs):
        """Reset the environment.

        For discrete action spaces the action mask is recomputed. If
        ``safe_center_obs`` is set, the center of the safe action space is
        appended to the returned observation.
        """
        obs = self.env.reset(**kwargs)
        if isinstance(self.action_space, gym.spaces.Discrete):
            self._mask = self._discrete_mask()
        obs = self.add_next_center_obs(obs)
        self.previous_fail_safe = False
        return obs

    def add_next_center_obs(self, obs):
        """Add the next center of the safe action space as observation.

        Returns ``obs`` unchanged if ``safe_center_obs`` is False.
        """
        if self.safe_center_obs:
            return np.concatenate((obs, self.next_center_obs()))
        return obs

    def _discrete_mask(self):
        """Compute the discrete action mask.

        Entry ``i`` is True if the state predicted by the dynamics function for
        (transformed) action ``i`` lies in the safe region. The last entry
        belongs to the auxiliary action and is True only if no other action is
        safe.
        """
        mask = np.zeros(self._num_actions, dtype=bool)
        for i in range(self._num_actions - 1):
            action = i if (self._transform_action_space_fn is None)\
                else self._transform_action_space_fn(i)
            if self._dynamics_fn(self.env, action) in self._safe_region:
                mask[i] = True

        if not mask.any():
            mask[-1] = True

        return mask

    def action_masks(self) -> np.ndarray:
        """Get the action mask (``None`` until a discrete mask was computed)."""
        return self._mask

    def transform_action(self, action: np.ndarray) -> np.ndarray:
        """Transform action if transformation function is given.

        Args:
            action: Action in domain [-1, 1]^n
        Returns:
            Action in environment domain.
        """
        if self._transform_action_space_fn is not None:
            return self._transform_action_space_fn(action)
        return action

    def transform_action_custom(self, action: np.ndarray, space: np.ndarray) -> np.ndarray:
        """Transform the action to a custom action space.

        Args:
            action: Action in domain [-1, 1]^n
            space: Custom action space in the format [[min_1, min_2, ..., min_n], [max_1, max_2, ..., max_n]]
        Returns:
            Action in given space (clipped to the bounds of ``space``).
        """
        return np.clip(((action + 1)/2) * (space[1]-space[0]) + space[0], space[0], space[1])

    def get_safe_space(self) -> Zonotope:
        """Calculate the safe space. Store the Zonotope in the physical space, but return the Zonotope in [-1, 1].

        A deep copy of the physical safe space is stored in ``self.safe_space``
        and used by the next call to :meth:`step`. If a safe space function
        raises ``ValueError``, ``None`` is stored and returned, which makes
        :meth:`step` execute the fail-safe action.
        """
        safe_space = None

        try:
            safe_space = self._continuous_safe_space_fn(self.env, self._safe_region)
            if self._continuous_safe_space_fn_polytope is not None:
                safe_space = self._continuous_safe_space_fn_polytope(self.env, self._safe_region_polytope)
        except ValueError:
            # No safe space could be computed for the current state.
            safe_space = None

        # This is very hacky and unstable ...
        self.safe_space = copy.deepcopy(safe_space)

        if safe_space is not None:
            safe_space = normalize_zonotope(safe_space, self.env.action_space.low, self.env.action_space.high)

        return safe_space

    def step(self, action):
        """Step the action in the environment.

        Discrete action spaces are delegated to :meth:`step_discrete`. For
        continuous action spaces the policy action is mapped into the stored
        safe space according to the configured masking mode; without a stored
        safe space the fail-safe action is executed.

        Returns:
            obs, reward, done, info
        Raises:
            NotImplementedError: If the configured masking mode is unknown.
        """

        # Discrete
        if isinstance(self.action_space, gym.spaces.Discrete):
            return self.step_discrete(action)

        # Note: Implemented like this, for the other approaches, self.safe_space is always None!
        if self.safe_space is None:
            return self.failsafe_action(action)

        safe_space = self.safe_space

        # Continuous action masking mode
        mode = self._continuous_action_masking_mode
        if mode is None or mode == ContMaskingMode.Interval:
            action = self.interval_masking(action, safe_space)
        elif mode == ContMaskingMode.Generator:
            action = self.generator_masking(action, safe_space)
        elif mode == ContMaskingMode.Ray:
            action = self.ray_masking(action, safe_space)
        elif mode == ContMaskingMode.ConstrainedNormal:
            # Only the transformation to the environment domain is applied here.
            action = self._transform_action_space_fn(action)
        else:
            # The original code called ``.format`` on the exception instance,
            # which raised an AttributeError instead of this error.
            raise NotImplementedError(f'The selected mode is not implemented. Mode selected: {mode}')

        obs, reward, done, info = self.env.step(action)
        info["masking"] = {"policy_action": action, "env_reward": reward,
                           "fail_safe_action": None, "pun_reward": None}

        if self._generate_wrapper_tuple:
            wrapper_action = self._inv_transform_action_space_fn(action) \
                if self._inv_transform_action_space_fn is not None else action
            info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))

        info["masking"]["safe_space"] = safe_space
        # Note: With the current implementation, you don't get the safe_space_polytope anymore
        if self._continuous_safe_space_fn_polytope is not None:
            info["masking"]["safe_space_polytope"] = safe_space

        obs = self.add_next_center_obs(obs)
        self.previous_fail_safe = False

        return obs, reward, done, info

    def interval_masking(self, action: Union[np.ndarray, List[float], float], safe_space: np.ndarray) -> np.ndarray:
        """Linearly rescale the action from the action space to a safe interval.

        Args:
            action: Policy action within the bounds of ``self.action_space``
            safe_space: Either ``[min, max]`` (1-D) or one ``[min, max]`` row per
                action dimension (2-D)
        Returns:
            Action inside the safe interval(s).
        Raises:
            NotImplementedError: If ``safe_space`` is not a numpy array.
        """
        # Scale policy action
        if not isinstance(action, np.ndarray):
            action = np.array(action)
        if not isinstance(safe_space, np.ndarray):
            raise NotImplementedError('Safe space must be a numpy array')
        if len(safe_space.shape) == 1:
            scale = (safe_space[1] - safe_space[0]) / (self.action_space.high - self.action_space.low)
            action = (scale * (action - self.action_space.low) + safe_space[0]).item()
        else:
            scale = (safe_space[:, 1] - safe_space[:, 0]) / (self.action_space.high - self.action_space.low)
            action = scale * (action - self.action_space.low) + safe_space[:, 0]
        return np.array(action)

    def generator_masking(self, action: Union[np.ndarray, List[float], float], safe_space: Zonotope) -> np.ndarray:
        """Interpret the action as generator factors of the safe zonotope.

        Returns:
            ``safe_space.G @ action + safe_space.c`` (center squeezed).
        """
        if not isinstance(action, np.ndarray):
            action = np.array(action)
        return safe_space.G @ action + np.squeeze(safe_space.c)

    def ray_masking(self, action: Union[np.ndarray, List[float], float], safe_space: Zonotope) -> np.ndarray:
        """Map the action into the safe zonotope along the ray from its center.

        The transformed action is moved along the ray ``center -> action`` such
        that the boundary of the full action space is mapped (almost) onto the
        boundary of the safe action space.

        Args:
            action: Policy action (input of ``transform_action_space_fn``)
            safe_space: Safe action space as zonotope
        Returns:
            Action inside the safe action space.
        """
        center = np.squeeze(safe_space.c)
        # TODO: Transformation from [-1,1] to physical box
        transformed_action = self._transform_action_space_fn(action)
        offset = transformed_action - center
        offset_norm = np.linalg.norm(offset)
        if offset_norm == 0:
            # The ray direction is undefined (0/0 would yield a NaN action);
            # the ray mapping keeps the center fixed.
            return np.array(center)
        # ---- Determine inside ratio (distance to edge of action space) ----
        # The inside ratio gives the scaling factor with which the shifted action has to be multiplied
        # to reach the edge of the action space in the direction of the normal vector of the critical halfspace.
        # Find relevant outside borders
        # TODO: check if this are the physical limits - should be correct
        max_a = self.env.action_space.high
        min_a = self.env.action_space.low
        # Build halfspace/polytope representation of the action space box:
        # H = [I; -I] (all upper-bound rows first, then all lower-bound rows)
        dim = transformed_action.shape[0]
        As_H_action_space = np.vstack([np.eye(dim), -np.eye(dim)])
        # d = [max_a_1, ..., max_a_n, -min_a_1, ..., -min_a_n]
        As_d_action_space = np.hstack([max_a, -min_a])
        # Find intersection points for action space
        alpha_action, idx_action = find_line_intersection_polytope(
            transformed_action,
            center,
            As_H_action_space,
            As_d_action_space
        )
        # Find intersection points for safe action space
        normalized_action = offset / offset_norm
        boundary_point = safe_space.boundary_point(np.reshape(normalized_action, (normalized_action.shape[0], 1))).squeeze()
        alpha_safe = np.linalg.norm(boundary_point - center)
        # Calculate ratio that brings the intersection point of the action space to the edge of the safe action space
        ratio = alpha_action / alpha_safe
        # Safety ratio to avoid floating point errors
        eps = 0.9999
        # Calculate new point
        action_final = center + 1/ratio * eps * offset
        return action_final

    def ray_masking_polytope(self, action: Union[np.ndarray, List[float], float], safe_space: np.ndarray) -> np.ndarray:
        """This function is deprecated and should not be used for zonotope action masking.

        Ray masking for a safe action space given by its vertices (one vertex
        per row of ``safe_space``).
        """
        # Compute halfspace/polytope representation
        As_H, As_d = compute_polytope_halfspaces(safe_space)
        # Find center of safe action space
        center = safe_space.mean(axis=0)
        # Transform action to outer approximation of safe action space
        min_As = np.min(safe_space, axis=0)
        max_As = np.max(safe_space, axis=0)
        transformed_action = self.transform_action_custom(action, np.array([min_As, max_As]))
        offset = transformed_action - center
        if np.linalg.norm(offset) == 0:
            # The ray direction is undefined (0/0 would yield a NaN action);
            # the ray mapping keeps the center fixed.
            return np.array(center)
        # ---- Determine inside ratio (distance to edge of action space) ----
        # The inside ratio gives the scaling factor with which the shifted action has to be multiplied
        # to reach the edge of the action space in the direction of the normal vector of the critical halfspace.
        # Find relevant outside borders
        if self._transform_action_space_fn is not None:
            max_a = self._transform_action_space_fn(np.ones_like(action))
            min_a = self._transform_action_space_fn(-np.ones_like(action))
        else:
            max_a = np.ones_like(action)
            min_a = -np.ones_like(action)
        # Build halfspace/polytope representation of the action space box:
        # H = [I; -I] (all upper-bound rows first, then all lower-bound rows)
        As_H_action_space = np.vstack([np.eye(self.action_space.shape[0]), -np.eye(self.action_space.shape[0])])
        # d = [max_a_1, ..., max_a_n, -min_a_1, ..., -min_a_n]
        As_d_action_space = np.hstack([max_a, -min_a])
        # Find intersection points for safe action space
        alpha_safe, idx_safe = find_line_intersection_polytope(transformed_action, center, As_H, As_d)
        # Find intersection points for action space
        alpha_action, idx_action = find_line_intersection_polytope(
            transformed_action,
            center,
            As_H_action_space,
            As_d_action_space
        )
        # Calculate ratio that brings the intersection point of the action space to the edge of the safe action space
        ratio = alpha_action / alpha_safe
        # Safety ratio to avoid floating point errors
        eps = 0.9999
        # Calculate new point
        action_final = center + 1/ratio * eps * offset
        return action_final

    def next_center_obs(self):
        """Get the next center of the safe action space as observation.

        The center is passed through ``inv_transform_action_space_fn``. If the
        safe space function raises ``ValueError``, a zero vector with the shape
        of the action space is returned.
        """
        try:
            next_safe_space = self._continuous_safe_space_fn(self.env, self._safe_region)
        except ValueError as e:
            print("Error in safe space function for next safe space: ", e)
            return np.zeros_like(self.action_space.low)

        if isinstance(next_safe_space, np.ndarray):
            center = next_safe_space.mean(axis=0)
        elif isinstance(next_safe_space, Zonotope):
            center = np.squeeze(next_safe_space.c)
            # if safe space is Zonotope the inv_transform_action_space_fn needs to be given the safe space as well
            return self._inv_transform_action_space_fn(center, next_safe_space)
        else:
            # Fallback: U_eq built from the module-level constants.
            center = np.array([GRAVITY/K, 0])
        return self._inv_transform_action_space_fn(center)

    def step_discrete(self, action):
        """Step function if action space is discrete.

        Executes the fail-safe action if only the auxiliary action is enabled
        in the current mask; otherwise the (transformed) policy action is
        executed and the mask for the next state is computed.

        Args:
            action: action to execute
        Returns:
            obs, reward, done, info
        """
        if self._mask[-1]:
            return self.failsafe_action(action)
        # Use the transformed action: the mask was computed on transformed
        # actions, and the original code discarded this return value.
        action = self.transform_action(action)
        # Policy action is safe
        obs, reward, done, info = self.env.step(action)
        info["masking"] = {"policy_action": action, "env_reward": reward,
                           "fail_safe_action": None, "pun_reward": None, "safe_space": self._mask[:-1]}
        # Compute next mask
        self._mask = self._discrete_mask()
        obs = self.add_next_center_obs(obs)
        self.previous_fail_safe = False
        return obs, reward, done, info

    def failsafe_action(self, action):
        """Execute failsafe action.

        If the safe region has a ``k_ctrl`` attribute, it is incremented while
        consecutive fail-safe steps are executed and reset to 0 on the first
        one. If a punishment function is given, its result replaces the reward.

        Args:
            action: original action, only used for reward punishment fn
        Returns:
            obs, reward, done, info
        """
        # Fallback to verified fail-safe control
        if hasattr(self._safe_region, 'k_ctrl'):
            if self.previous_fail_safe:
                self._safe_region.k_ctrl += 1
            else:
                self._safe_region.k_ctrl = 0
        safe_action = self._safe_control_fn(self.env, self._safe_region)
        obs, reward, done, info = self.env.step(safe_action)
        info["masking"] = {"policy_action": None, "env_reward": reward,
                           "fail_safe_action": safe_action, "safe_space": None}
        if self._generate_wrapper_tuple:
            wrapper_action = self._inv_transform_action_space_fn(safe_action) \
                if self._inv_transform_action_space_fn is not None else safe_action
            info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))
        # Optional reward punishment
        if self._punishment_fn is not None:
            punishment = self._punishment_fn(
              env=self.env,
              action=action,
              reward=reward,
              safe_action=safe_action
            )
            info["masking"]["pun_reward"] = punishment
            reward = punishment
        else:
            info["masking"]["pun_reward"] = None
        obs = self.add_next_center_obs(obs)
        self.previous_fail_safe = True
        return obs, reward, done, info


def normalize_halfspace(H, d):
    """Rescale a halfspace system so that every normal vector has length one.

    Each row of ``H`` and the matching entry of ``d`` are divided by the
    Euclidean norm of that row.

    Args:
        H: Halfspace matrix
        d: Halfspace vector
    Returns:
        Normalized halfspace matrix and vector
    """
    # Column vector of row norms, so it broadcasts over the columns of H.
    row_lengths = np.linalg.norm(H, axis=1, keepdims=True)
    return H / row_lengths, d / row_lengths[:, 0]


def find_line_intersection_polytope(p, c, H, d):
    r"""Find the first positive intersection of a line with a polygon in halfspace/polytope representation.

    x = c + \alpha * (p - c)/||p - c||
    H * x = d
    min_i \alpha_i s.t. H_i * x = d_i, \alpha_i >= 0

    Halfspaces whose boundary is parallel to the line (``H_i * n == 0``) are
    never intersected and are treated as ``\alpha_i = inf``; this also avoids
    the division-by-zero warnings and NaN results of a plain division.

    Args:
        p: Point
        c: Center of polytope
        H: Halfspace matrix
        d: Halfspace vector
    Returns:
        \alpha (double): Line length to intersection point. ``inf`` if no
            halfspace is intersected for \alpha >= 0, ``nan`` if ``p == c``
            (direction undefined).
        idx (int): Index of halfspace that is intersected first
    """
    p = np.asarray(p, dtype=float)
    c = np.asarray(c, dtype=float)
    H = np.asarray(H, dtype=float)
    d = np.asarray(d, dtype=float)

    direction = p - c
    length = np.linalg.norm(direction)
    if length == 0:
        # No direction can be derived from two identical points.
        return np.nan, 0

    n = direction / length
    slack = d - H @ c
    Hn = H @ n
    # Parallel halfspaces produce x/0 (+-inf) or 0/0 (nan); both are handled below.
    with np.errstate(divide="ignore", invalid="ignore"):
        alpha = slack / Hn
    # 0/0: the line runs inside the boundary hyperplane and never crosses it.
    alpha = np.where(np.isnan(alpha), np.inf, alpha)
    # Intersections behind the center are irrelevant.
    alpha[alpha < 0] = np.inf
    idx = np.argmin(alpha)
    return alpha[idx], idx
