import numpy as np
from gymnasium import ObservationWrapper, spaces
from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX


class MinigridObservationWrapper(ObservationWrapper):
    """Replace MiniGrid observations with channel-first partial and full views.

    The returned observation is a dict with two ``uint8`` arrays:

    * ``"obs"``: the wrapped env's ``"image"`` observation, with its last
      (channel) axis moved to the front: ``(d0, d1, C) -> (C, d0, d1)``.
    * ``"state"``: the full-grid encoding from ``grid.encode()`` with the
      agent's cell overwritten by ``(agent, red, agent_dir)``, likewise moved
      to channel-first: ``(3, width, height)``.
    """

    def __init__(self, env):
        super().__init__(env)

        # Read grid dimensions from the base env, not ``self.env``: ``self.env``
        # may itself be a wrapper, and newer gymnasium versions no longer
        # forward arbitrary attribute lookups (``width``/``height``) through
        # wrappers.
        base_env = self.unwrapped
        img_d0, img_d1, img_channels = self.observation_space.spaces["image"].shape

        self.observation_space = spaces.Dict(
            {
                "obs": spaces.Box(
                    low=0,
                    high=255,
                    shape=(img_channels, img_d0, img_d1),
                    dtype="uint8",
                ),
                "state": spaces.Box(
                    low=0,
                    high=255,
                    shape=(3, base_env.width, base_env.height),
                    dtype="uint8",
                ),
            }
        )

    def observation(self, obs):
        """Build the ``{"obs", "state"}`` dict from a wrapped observation.

        Args:
            obs: Observation dict from the wrapped env; only ``obs["image"]``
                is read.

        Returns:
            dict with ``"obs"`` (the partial-view image, channel-first) and
            ``"state"`` (the full-grid encoding with the agent stamped in,
            channel-first).
        """
        env = self.unwrapped
        full_view = env.grid.encode()

        # Overwrite the agent's cell with its encoding (object type, red color,
        # facing direction); whatever the grid encoding held there is replaced.
        agent_x, agent_y = env.agent_pos
        full_view[agent_x, agent_y] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
        )

        return {
            "obs": obs["image"].transpose(2, 0, 1),
            "state": full_view.transpose(2, 0, 1),
        }
