import os
from multiprocessing import dummy

import gym
import numpy as np
import torch
from gym.spaces.box import Box
from gym.wrappers import Monitor
from pettingzoo.sisl import pursuit_v4
from ray import tune
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize as VecNormalize_
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
from wrappers import Monitor, TimeLimit

from suPER.experiments.env import (
    env_creator_adversarial_pursuit,
    env_creator_battle,
    env_creator_pursuit,
)

# Register every suPER environment creator with Ray Tune under its string id.
for _env_name, _env_creator in (
    ("pursuit", env_creator_pursuit),
    ("battle", env_creator_battle),
    ("adversarial_pursuit", env_creator_adversarial_pursuit),
):
    # The creator is bound as a default argument so that each registered
    # callable keeps its own creator instead of the loop's last value.
    tune.register_env(_env_name, lambda config, _creator=_env_creator: _creator(config))


class wrap_pursuit(gym.Wrapper):
    """Expose the pursuit env's per-agent dicts as ordered per-agent lists.

    The wrapped env is indexed with the keys ``"pursuer_0"`` .. ``"pursuer_7"``
    for observations, rewards and dones; this wrapper converts those dicts to
    lists ordered by agent index and converts a sequence of actions back into
    a dict with the same keys.
    """

    def __init__(self, env, **kwargs):
        # These attributes are patched onto the raw env before the gym.Wrapper
        # constructor runs.
        # NOTE(review): presumably the wrapped env does not define them itself
        # and gym.Wrapper expects them - confirm.
        env.reward_range = (-np.inf, np.inf)
        env.spec = None
        super().__init__(env, **kwargs)

        agent_count = 8
        obs_spaces = [
            gym.spaces.Box(low=0.0, high=30.0, shape=(147,), dtype=np.float32)
            for _ in range(agent_count)
        ]
        act_spaces = [gym.spaces.Discrete(5) for _ in range(agent_count)]
        self.observation_space = gym.spaces.Tuple(obs_spaces)
        self.action_space = gym.spaces.Tuple(act_spaces)
        self.n_agents = agent_count

    def reset(self):
        """Reset the wrapped env and return the observations as a list by agent index."""
        obs_by_agent = self.env.reset()
        ordered_obs = []
        for idx in range(8):
            ordered_obs.append(obs_by_agent[f"pursuer_{idx}"])
        return ordered_obs

    def step(self, actions):
        """Step the wrapped env with ``actions[i]`` assigned to ``pursuer_i``.

        Returns:
            ``(observations, rewards, dones, {})`` where the first three are
            lists ordered by agent index. The info returned by the wrapped env
            is dropped; an empty dict is always returned in its place.
        """
        agent_keys = [f"pursuer_{idx}" for idx in range(8)]
        joint_action = {key: actions[idx] for idx, key in enumerate(agent_keys)}
        obs, rewards, dones, _ = self.env.step(joint_action)
        obs_list = [obs[key] for key in agent_keys]
        reward_list = [rewards[key] for key in agent_keys]
        done_list = [dones[key] for key in agent_keys]
        return obs_list, reward_list, done_list, {}


class wrap_battle(gym.Wrapper):
    """Expose the battle env's per-agent dicts as ordered per-agent lists.

    The wrapped env is indexed with the keys ``"blue_0"`` .. ``"blue_5"`` for
    observations, rewards and dones; this wrapper converts those dicts to
    lists ordered by agent index and converts a sequence of actions back into
    a dict with the same keys.
    """

    def __init__(self, env, **kwargs):
        # These attributes are patched onto the raw env before the gym.Wrapper
        # constructor runs.
        # NOTE(review): presumably the wrapped env does not define them itself
        # and gym.Wrapper expects them - confirm.
        env.reward_range = (-np.inf, np.inf)
        env.spec = None
        super().__init__(env, **kwargs)

        agent_count = 6
        obs_spaces = [
            gym.spaces.Box(low=0.0, high=2.0, shape=(845,), dtype=np.float32)
            for _ in range(agent_count)
        ]
        act_spaces = [gym.spaces.Discrete(21) for _ in range(agent_count)]
        self.observation_space = gym.spaces.Tuple(obs_spaces)
        self.action_space = gym.spaces.Tuple(act_spaces)
        self.n_agents = agent_count

    def reset(self):
        """Reset the wrapped env and return the observations as a list by agent index."""
        obs_by_agent = self.env.reset()
        ordered_obs = []
        for idx in range(6):
            ordered_obs.append(obs_by_agent[f"blue_{idx}"])
        return ordered_obs

    def step(self, actions):
        """Step the wrapped env with ``actions[i]`` assigned to ``blue_i``.

        Returns:
            ``(observations, rewards, dones, {})`` where the first three are
            lists ordered by agent index. The info returned by the wrapped env
            is dropped; an empty dict is always returned in its place.
        """
        agent_keys = [f"blue_{idx}" for idx in range(6)]
        joint_action = {key: actions[idx] for idx, key in enumerate(agent_keys)}
        obs, rewards, dones, _ = self.env.step(joint_action)
        obs_list = [obs[key] for key in agent_keys]
        reward_list = [rewards[key] for key in agent_keys]
        done_list = [dones[key] for key in agent_keys]
        return obs_list, reward_list, done_list, {}


class wrap_advp(gym.Wrapper):
    """Expose the adversarial-pursuit env's per-agent dicts as ordered lists.

    The wrapped env is indexed with the keys ``"prey_0"`` .. ``"prey_7"`` for
    observations, rewards and dones; this wrapper converts those dicts to
    lists ordered by agent index and converts a sequence of actions back into
    a dict with the same keys.
    """

    def __init__(self, env, **kwargs):
        # These attributes are patched onto the raw env before the gym.Wrapper
        # constructor runs.
        # NOTE(review): presumably the wrapped env does not define them itself
        # and gym.Wrapper expects them - confirm.
        env.reward_range = (-np.inf, np.inf)
        env.spec = None
        super().__init__(env, **kwargs)

        agent_count = 8
        obs_spaces = [
            gym.spaces.Box(low=0.0, high=2.0, shape=(500,), dtype=np.float32)
            for _ in range(agent_count)
        ]
        act_spaces = [gym.spaces.Discrete(13) for _ in range(agent_count)]
        self.observation_space = gym.spaces.Tuple(obs_spaces)
        self.action_space = gym.spaces.Tuple(act_spaces)
        self.n_agents = agent_count

    def reset(self):
        """Reset the wrapped env and return the observations as a list by agent index."""
        obs_by_agent = self.env.reset()
        ordered_obs = []
        for idx in range(8):
            ordered_obs.append(obs_by_agent[f"prey_{idx}"])
        return ordered_obs

    def step(self, actions):
        """Step the wrapped env with ``actions[i]`` assigned to ``prey_i``.

        Returns:
            ``(observations, rewards, dones, {})`` where the first three are
            lists ordered by agent index. The info returned by the wrapped env
            is dropped; an empty dict is always returned in its place.
        """
        agent_keys = [f"prey_{idx}" for idx in range(8)]
        joint_action = {key: actions[idx] for idx, key in enumerate(agent_keys)}
        obs, rewards, dones, _ = self.env.step(joint_action)
        obs_list = [obs[key] for key in agent_keys]
        reward_list = [rewards[key] for key in agent_keys]
        done_list = [dones[key] for key in agent_keys]
        return obs_list, reward_list, done_list, {}


class MADummyVecEnv(DummyVecEnv):
    """DummyVecEnv variant whose reward buffer has one column per agent."""

    def __init__(self, env_fns):
        super().__init__(env_fns)
        # The number of agents is taken from the length of the observation space.
        n_agents = len(self.observation_space)
        # Replace the parent's reward buffer with a (num_envs, n_agents) one,
        # because each env step yields more than one reward.
        reward_shape = (self.num_envs, n_agents)
        self.buf_rews = np.zeros(reward_shape, dtype=np.float32)


def make_env(env_id, seed, rank, time_limit, wrappers, monitor_dir):
    """Build a zero-argument factory that creates one wrapped environment.

    Nothing is constructed until the returned thunk is called.

    Args:
        env_id: One of ``"battle"``, ``"adversarial_pursuit"`` or ``"pursuit"``.
        seed: Value stored under the ``"seed"`` key of the env config.
        rank: Index of this env among its siblings. Currently unused.
            NOTE(review): because ``rank`` is ignored, every thunk built with
            the same ``seed`` passes an identical seed to its env - confirm
            this is intended.
        time_limit: If truthy, the env is wrapped in ``TimeLimit(env, time_limit)``.
        wrappers: Iterable of callables ``wrapper(env) -> env`` applied in order
            after the optional time limit.
        monitor_dir: Currently unused; kept for interface compatibility.

    Returns:
        A callable taking no arguments and returning the wrapped environment.

    Raises:
        ValueError: When the returned thunk is called and ``env_id`` is not a
            supported environment name.
    """

    def _thunk():
        if env_id in ("battle", "adversarial_pursuit"):
            # Battle and AdvPursuit share the same configuration.
            # "wrap_pretrained_agents" makes the env use pretrained agents, but purely internal to the env.
            # "black death" is needed here due to how we wrap it for SEAC. Shouldn't be needed for other envs.
            env_config = {
                "map_size": 18,
                "flatten_obs": True,
                "actions_are_logits": False,
                "wrap_pretrained_agents": True,
                "black_death": True,
                "seed": seed,
            }
            if env_id == "battle":
                env = wrap_battle(env_creator_battle(env_config))
            else:
                env = wrap_advp(env_creator_adversarial_pursuit(env_config))

        elif env_id == "pursuit":
            env_config = {
                "num_agents": 8,
                "n_evaders": 30,
                "shared_reward": False,
                "flatten_obs": True,
                "seed": seed,
            }
            env = wrap_pursuit(env_creator_pursuit(env_config))

        else:
            # Fail with a clear message instead of an UnboundLocalError on `env` below.
            raise ValueError(
                f"Unknown env_id {env_id!r}; expected one of "
                "'battle', 'adversarial_pursuit', 'pursuit'"
            )

        if time_limit:
            env = TimeLimit(env, time_limit)
        for wrapper in wrappers:
            env = wrapper(env)

        return env

    return _thunk


def make_vec_envs(env_name, seed, dummy_vecenv, parallel, time_limit, wrappers, device, monitor_dir=None):
    """Create ``parallel`` environments, vectorize them and wrap them for torch.

    Args:
        env_name: Environment id forwarded to ``make_env``.
        seed: Seed forwarded to ``make_env``.
        dummy_vecenv: Ignored; it is overwritten with ``True`` below.
        parallel: Number of environment copies to create.
        time_limit: Forwarded to ``make_env``.
        wrappers: Forwarded to ``make_env``.
        device: Torch device handed to ``VecPyTorch``.
        monitor_dir: Forwarded to ``make_env``.

    Returns:
        A ``VecPyTorch`` instance wrapping the vectorized environments.
    """
    # The caller's choice is overridden, so the in-process vec env is always
    # selected and the SubprocVecEnv branch below is never taken.
    dummy_vecenv = True

    thunks = []
    for rank in range(parallel):
        thunks.append(make_env(env_name, seed, rank, time_limit, wrappers, monitor_dir))

    use_subprocesses = not (dummy_vecenv or len(thunks) == 1 or monitor_dir)
    if use_subprocesses:
        vec_env = SubprocVecEnv(thunks, start_method="fork")
    else:
        vec_env = MADummyVecEnv(thunks)

    return VecPyTorch(vec_env, device)


class VecPyTorch(VecEnvWrapper):
    """Vec-env wrapper that converts between torch tensors and numpy arrays.

    Observations, rewards and dones coming out of the wrapped vec env are
    converted to float tensors on ``device``; actions going in are converted
    from per-agent tensors to per-env tuples of numpy values.
    """

    def __init__(self, venv, device):
        """Wrap ``venv`` and remember the torch ``device`` to place tensors on."""
        super().__init__(venv)
        self.device = device

    def reset(self):
        """Reset the wrapped vec env and return one float tensor per observation entry."""
        obs = self.venv.reset()
        # Cast to float like step_wait does, so both entry points return the same dtype.
        return [torch.from_numpy(o).float().to(self.device) for o in obs]

    def step_async(self, actions):
        """Send actions to the wrapped vec env.

        Args:
            actions: Iterable with one tensor per agent, each holding that
                agent's action for every env.
        """
        # squeeze() drops singleton dimensions; with a single env that yields a
        # 0-d array that zip() cannot iterate, so at least one axis is restored.
        actions = [np.atleast_1d(a.squeeze().cpu().numpy()) for a in actions]
        # Transpose from per-agent sequences to one tuple of agent actions per env.
        actions = list(zip(*actions))
        return self.venv.step_async(actions)

    def step_wait(self):
        """Collect the step result and return ``(obs, rewards, dones, info)``.

        ``obs`` is a list of float tensors, ``rewards`` and ``dones`` are float
        tensors, all on ``self.device``; ``info`` is passed through unchanged.
        """
        obs, rew, done, info = self.venv.step_wait()
        return (
            [torch.from_numpy(o).float().to(self.device) for o in obs],
            torch.from_numpy(rew).float().to(self.device),
            torch.from_numpy(done).float().to(self.device),
            info,
        )
