from glob import glob
import json
from pathlib import Path
from statistics import mean
from typing import Dict, List, Optional, Tuple

from dotenv import load_dotenv
from fire import Fire
import rrls
from rrls.evaluate import (
    EVALUATION_ROBUST_ANT_2D,
    EVALUATION_ROBUST_ANT_3D,
    EVALUATION_ROBUST_HOPPER_2D,
    EVALUATION_ROBUST_HOPPER_3D,
    EVALUATION_ROBUST_WALKER_2D,
    EVALUATION_ROBUST_WALKER_3D,
    EVALUATION_ROBUST_HALF_CHEETAH_2D,
    EVALUATION_ROBUST_HALF_CHEETAH_3D,
    EVALUATION_ROBUST_HUMANOID_STANDUP_2D,
    EVALUATION_ROBUST_HUMANOID_STANDUP_3D,
)
from tqdm import tqdm

from evaluation import Agent, evaluate
from tc_mdp import EvalStaticOracle, EvalStaticStacked
from utils import bound_factory, AgentInferenceBuilder

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Evaluation environment sets, indexed first by the number of uncertainty
# dimensions and then by environment name. Every name in ENV_NAMES must be a
# key of each inner dict, because grid_performance_multiple_envs looks names
# from ENV_NAMES up here.
EVAL_SET = {
    2: {
        "Ant": EVALUATION_ROBUST_ANT_2D,
        "Hopper": EVALUATION_ROBUST_HOPPER_2D,
        "Walker": EVALUATION_ROBUST_WALKER_2D,
        "HalfCheetah": EVALUATION_ROBUST_HALF_CHEETAH_2D,
        "HumanoidStandup": EVALUATION_ROBUST_HUMANOID_STANDUP_2D,
        # Legacy key: this entry used to be registered only as "Humanoid",
        # which made the "HumanoidStandup" lookup fail in 2D. The old key is
        # kept so existing direct lookups keep working.
        "Humanoid": EVALUATION_ROBUST_HUMANOID_STANDUP_2D,
    },
    3: {
        "Ant": EVALUATION_ROBUST_ANT_3D,
        "Hopper": EVALUATION_ROBUST_HOPPER_3D,
        "Walker": EVALUATION_ROBUST_WALKER_3D,
        "HalfCheetah": EVALUATION_ROBUST_HALF_CHEETAH_3D,
        "HumanoidStandup": EVALUATION_ROBUST_HUMANOID_STANDUP_3D,
    },
}

# Environment names that are searched for (as substrings) in agent paths.
ENV_NAMES = {"Ant", "Hopper", "Walker", "HalfCheetah", "HumanoidStandup"}


def get_agent(agent_path, agent_type, env, nb_uncertainty_dim, device):
    """Build an inference agent from a saved actor.

    An ``AgentInferenceBuilder`` is created for ``env``, ``nb_uncertainty_dim``
    and ``device``; it is then given the actor path, the device and the agent
    type (in that order) before ``build()`` is called.

    Returns:
        The object produced by the builder's ``build()`` call.
    """
    builder = AgentInferenceBuilder(
        env=env,
        nb_dim=nb_uncertainty_dim,
        device=device,
    )
    # Each builder step returns the object the next step is invoked on.
    builder = builder.add_actor_path(path=agent_path)
    builder = builder.add_device(device)
    builder = builder.add_agent_type(agent_type)
    built_agent: Agent = builder.build()
    return built_agent


def grid_performance(
    agent_path: str,
    env_name: str,
    nb_uncertainty_dim: int,
    num_episodes: int = 1,
    device: str = "cuda:0",
    agent_type: str = "td3",
    save_folder: Optional[str] = None,
    prefix_name: str = "",
):
    """Evaluate one agent on every environment of an evaluation grid.

    For each environment in ``EVAL_SET[nb_uncertainty_dim][env_name]`` the
    agent is evaluated for ``num_episodes`` episodes (seed 0) and the mean
    reward is recorded together with the environment parameters whose names
    appear in the bounds returned by ``bound_factory``. The records are sorted
    by increasing mean reward and, if ``save_folder`` is given, written as JSON
    to ``{save_folder}/{prefix_name}grid_eval_info.json``.

    Args:
        agent_path: Path of the saved actor; also used by ``wrap_env`` to
            decide which wrappers to apply.
        env_name: Key into ``EVAL_SET[nb_uncertainty_dim]``.
        nb_uncertainty_dim: Key into ``EVAL_SET`` (2 or 3).
        num_episodes: Number of episodes evaluated per environment.
        device: Device string forwarded to the agent builder.
        agent_type: Agent type forwarded to the agent builder.
        save_folder: Folder receiving the JSON file. When ``None`` nothing is
            written.
        prefix_name: Prefix prepended to the JSON file name.

    Returns:
        None. If the output file already exists the function prints a message
        and returns before running any evaluation.

    Raises:
        KeyError: If ``nb_uncertainty_dim`` or ``env_name`` is not in
            ``EVAL_SET`` (only reached when no cached result file exists).
    """
    # The target file does not depend on the environment, so the "already
    # evaluated" check is done once, before any lookup or evaluation work.
    output_file = None
    if save_folder is not None:
        output_file = f"{save_folder}/{prefix_name}grid_eval_info.json"
        if Path(output_file).exists():
            print(f"File {output_file} already exists")
            return

    env_set = EVAL_SET[nb_uncertainty_dim][env_name]
    bounds: Dict[str, List[float]] = bound_factory(
        env_name=env_name, nb_dim=nb_uncertainty_dim
    )
    episodes_infos: List[Tuple[Dict[str, float], float]] = []
    agent = None

    for base_env in env_set:
        # Read the parameters from the unwrapped environment.
        mdp_params = base_env.get_params()
        eval_env = wrap_env(base_env, agent_path, bounds)

        # The agent is built once, from the first wrapped environment, and
        # reused for every other environment of the grid.
        if agent is None:
            agent = get_agent(
                agent_path=agent_path,
                agent_type=agent_type,
                env=eval_env,
                nb_uncertainty_dim=nb_uncertainty_dim,
                device=device,
            )

        # Keep only the parameters that are part of the uncertainty set.
        filtered_mdp_params = {k: v for k, v in mdp_params.items() if k in bounds}
        rewards: List[float] = evaluate(
            env=eval_env, agent=agent, seed=0, num_episodes=num_episodes
        )
        episodes_infos.append((filtered_mdp_params, mean(rewards)))

    # Worst-performing environments first.
    episodes_infos.sort(key=lambda info: info[1])

    if output_file is not None:
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(episodes_infos, f)


def grid_performance_multiple_envs(
    agent_glob_path: str,
    device: str,
    num_episodes: int = 1,
    nb_uncertainty_dim: int = 3,
    agent_type: str = "m2td3",
):
    """Run ``grid_performance`` for every agent file matching a glob pattern.

    The matched paths are processed in sorted order with a progress bar. For
    each path, every name of ``ENV_NAMES`` that occurs as a substring of the
    path triggers one evaluation; paths containing no known name are skipped.
    Results are saved next to the agent file, with a file-name prefix made of
    the agent file name and the number of episodes.
    """
    agent_files = glob(agent_glob_path)
    agent_files.sort()

    for agent_file in tqdm(agent_files):
        matching_envs = [name for name in ENV_NAMES if name in agent_file]
        if not matching_envs:
            continue

        location = Path(agent_file)
        for matched_env in matching_envs:
            grid_performance(
                agent_path=agent_file,
                env_name=matched_env,
                nb_uncertainty_dim=nb_uncertainty_dim,
                num_episodes=num_episodes,
                device=device,
                agent_type=agent_type,
                save_folder=str(location.parent),
                prefix_name=f"{location.name}_num_ep_{num_episodes}_",
            )


def wrap_env(env, agent_path, bounds):
    """Wrap ``env`` according to markers found in ``agent_path``.

    ``EvalStaticOracle`` is applied when the path contains "oracle" and
    ``EvalStaticStacked`` when it contains "stacked". The two checks are
    independent, so a path containing both markers gets both wrappers (oracle
    innermost). Without any marker the environment is returned unchanged.
    """
    needs_oracle = "oracle" in agent_path
    needs_stacked = "stacked" in agent_path

    # Nothing to do: hand back the very same object.
    if not (needs_oracle or needs_stacked):
        return env

    wrapped = env
    if needs_oracle:
        wrapped = EvalStaticOracle(wrapped, bounds)
    if needs_stacked:
        wrapped = EvalStaticStacked(wrapped, bounds)
    return wrapped


if __name__ == "__main__":
    # Expose grid_performance_multiple_envs as the command-line entry point;
    # Fire maps command-line arguments onto the function's parameters.
    Fire(grid_performance_multiple_envs)
