import random
import numpy as np


def make_env(scenario_name, **env_kwargs):
    """Build a PettingZoo MPE parallel environment for ``scenario_name``.

    Args:
        scenario_name: one of ``"simple_speaker_listener"``,
            ``"simple_spread"`` or ``"simple_reference"``.
        **env_kwargs: optional keyword arguments forwarded unchanged to the
            scenario's ``parallel_env`` constructor. Omitting them gives the
            same environment as before.

    Returns:
        The scenario's ``parallel_env(**env_kwargs)`` instance.

    Raises:
        ValueError: if ``scenario_name`` is not one of the supported scenarios.
    """
    # Scenario modules are imported lazily, so only the requested one is loaded.
    if scenario_name == "simple_speaker_listener":
        from pettingzoo.mpe import simple_speaker_listener_v4 as mpe_env
    elif scenario_name == "simple_spread":
        from pettingzoo.mpe import simple_spread_v3 as mpe_env
    elif scenario_name == "simple_reference":
        from pettingzoo.mpe import simple_reference_v3 as mpe_env
    else:
        # A bare ``raise`` here has no active exception and would only produce
        # "RuntimeError: No active exception to reraise".
        raise ValueError(
            f"Unknown MPE scenario {scenario_name!r}; expected one of "
            "'simple_speaker_listener', 'simple_spread', 'simple_reference'."
        )
    return mpe_env.parallel_env(**env_kwargs)


class MPEWrapper:
    """SMAC-style adapter around a PettingZoo MPE parallel environment.

    Exposes:
      * per-agent observations, zero-padded to the widest agent observation;
      * a global state: all agents' raw observations concatenated, repeated
        once per agent;
      * a per-agent boolean action-availability mask.

    ``step`` returns ``(reward, done, info)`` and resets the environment
    automatically when the episode ends.
    """

    def __init__(self, map_name, seed):
        """Create the wrapped environment and perform the first reset.

        Args:
            map_name: scenario name accepted by ``make_env``.
            seed: seed for NumPy's and ``random``'s global RNGs and for the
                first environment reset. It is incremented after every reset.
        """
        # Compatibility shim for dependencies that still use the ``np.bool``
        # alias (deprecated in NumPy 1.20, removed in 1.24). NumPy 2 defines
        # ``np.bool`` again as its own boolean scalar type, so only install
        # the alias when the module does not already define the name.
        if "bool" not in vars(np):
            np.bool = bool
        np.random.seed(seed)
        random.seed(seed)
        self.env_name = map_name
        self.seed = seed
        self._reset()
        # Reported as-is by get_env_info; presumably a "not applicable"
        # sentinel for SMAC-oriented callers.
        self.n_enemies = -1

    def _reset(self):
        """(Re)build the underlying environment and reset it."""
        self._env = make_env(self.env_name)
        self.reset()

    def reset(self):
        """Reset the environment with the current seed, then advance the seed.

        Returns:
            ``(obs, states, avails)`` for the first step of the new episode.
        """
        self._observations, _ = self._env.reset(seed=self.seed)
        self.go_count = 0
        self.total_reward = 0.0
        # A different seed per episode keeps the episode sequence
        # reproducible from the constructor seed without repeating episodes.
        self.seed += 1
        return self._get_current_states()

    def get_env_info(self):
        """Return environment dimensions in the dict layout SMAC code expects."""
        return {
            "obs_shape": self.ob_dim,
            "state_shape": self.st_dim,
            "n_actions": self.ac_dim,
            "n_agents": self.n_agents,
            "n_enemies": self.n_enemies,
            # Hard-coded; assumes make_env uses the MPE default max_cycles
            # of 25 -- TODO confirm if make_env is given other settings.
            "episode_limit": 25
        }

    def step(self, actions):
        """Apply one joint action and return ``(reward, done, info)``.

        Args:
            actions: one discrete action per agent, in ``self.agent_names``
                order. Each is converted with ``int``.

        Returns:
            reward: mean of the agents' rewards for this step.
            done: True if any agent was terminated or truncated.
            info: empty, except on the final step, where ``"battle_won"``
                holds the episode's accumulated mean reward.

        When the episode ends, the environment is reset immediately, so the
        cached observations already belong to the next episode.
        """
        joint_action = {
            agent: int(action) for agent, action in zip(self.agent_names, actions)
        }
        self.go_count += 1
        self._observations, rewards, terminations, truncations, _ = self._env.step(joint_action)
        reward = np.mean([rewards[agent] for agent in self.agent_names])
        # PettingZoo reports episode end through both terminations and
        # truncations; either one ends the episode.
        done = np.any([
            bool(terminations[agent]) or bool(truncations[agent])
            for agent in self.agent_names
        ])
        info = {}
        self.total_reward += reward
        if done:
            # Reuses SMAC's "battle_won" key to report the episode return.
            info["battle_won"] = self.total_reward
            self.reset()
        else:
            self._get_current_states()
        return reward, done, info

    def get_obs(self):
        """Return an ``(n_agents, ob_dim)`` array of zero-padded observations."""
        obs = [self._observations[agent] for agent in self.agent_names]
        # Agents' observation lengths may differ; right-pad with zeros.
        return np.array([np.pad(x, (0, self.ob_dim - len(x))) for x in obs])

    def get_state(self):
        """Return the concatenation of all agents' raw (unpadded) observations."""
        obs = [self._observations[agent] for agent in self.agent_names]
        return np.concatenate(obs)

    def get_avail_actions(self):
        """Return an ``(n_agents, ac_dim)`` boolean mask of usable actions.

        Row ``i`` enables exactly the first ``action_spaces[i]`` actions.
        Agents with fewer actions than ``ac_dim`` have the padded tail
        disabled.
        """
        avails = np.zeros((self.n_agents, self.ac_dim), dtype=bool)
        # Each agent's own action count applies only to that agent's row.
        for agent_idx, n_actions in enumerate(self.action_spaces):
            avails[agent_idx, :n_actions] = True
        return avails

    def _get_current_states(self):
        """Refresh cached dimensions, observations, state and masks.

        Returns:
            ``(obs, states, avails)``, as from ``get_current_states``.
        """
        self.agent_names = list(self._observations.keys())
        self.n_agents = len(self.agent_names)
        self.ob_dim = max(int(self._observations[agent].shape[0]) for agent in self.agent_names)
        self.st_dim = sum(int(self._observations[agent].shape[0]) for agent in self.agent_names)
        self.action_spaces = [int(self._env.action_space(agent).n) for agent in self.agent_names]
        self.ac_dim = max(self.action_spaces)
        self.obs = self.get_obs()
        # Every agent receives the same global state.
        self.states = np.stack([self.get_state()] * self.n_agents)
        self.avails = self.get_avail_actions()
        return self.get_current_states()

    def get_current_states(self):
        """Return the cached ``(obs, states, avails)`` triple."""
        return self.obs, self.states, self.avails