import re
from typing import Dict, List, Tuple
import os
from utils import env_factory, bound_factory, scheduler_factory, AgentInferenceBuilder
from evaluation import Agent, evaluate_with_scheduler, evaluate
from fire import Fire
from dotenv import load_dotenv
from tqdm import tqdm
import pandas as pd
from tc_mdp import EvalOracleTCMDP, EvalStackedTCMDP

import warnings

# Silence the pandas FutureWarning about concatenating empty or all-NA frames.
# `message` is treated as a regex matched against the start of the warning text.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated.",
)

# Load environment variables from a local .env file (python-dotenv), if any.
load_dotenv()


# Seed handed to every evaluation run and scheduler in this script, so that all
# agents are evaluated with the same seed.
EVAL_SEED = 42


def main(
    all_agents_folder: str,
    output_folder: str,
    nb_episodes: int = 1,
    device: str = "cuda:0",
    verbose: bool = False,
):
    """
    Evaluate every agent found under ``all_agents_folder`` and write the results to a CSV.

    The folder is walked as ``<all_agents_folder>/<algo>/<env_name>/<nb_uncertainty_dim>/<agent>``.
    Each agent is evaluated with ``evaluate_agent`` and the per-scheduler rows are
    tagged with the algorithm, environment, uncertainty dimension and a seed index,
    then saved to ``<output_folder>/results.csv``.

    Args:
        all_agents_folder (str): Root folder containing one sub-folder per algorithm.
        output_folder (str): Folder where ``results.csv`` is written (created if missing).
        nb_episodes (int): Number of episodes to run for each scheduler type.
        device (str): Device passed to the agent builder.
        verbose (bool): Whether to print the rewards of each evaluation.
    """
    result_columns = [
        "algo",
        "env_name",
        "uncertainty_dim",
        "seed",
        "scheduler_type",
        "rewards_mean",
        "rewards_std",
        "rewards",
    ]
    env_names = ["Ant", "HalfCheetah", "Hopper", "HumanoidStandup", "Walker"]
    # Environment wrappers selected by a keyword found in the agent path.
    env_wrappers = (("oracle", EvalOracleTCMDP), ("stacked", EvalStackedTCMDP))

    # Collect the per-agent frames and concatenate once at the end instead of
    # growing a DataFrame inside the loop (which copies all rows each time).
    per_agent_results = []

    # os.listdir returns entries in arbitrary order; sorting keeps the row order
    # and the "seed" index reproducible across runs and machines.
    for algo in tqdm(sorted(os.listdir(all_agents_folder)), desc="Algo", leave=False):
        is_vanilla = "vanilla" in algo and "tc" not in algo
        nb_uncertainty_dim = 0 if is_vanilla else 3

        # Algorithms whose name contains "tc-" / "tc_" are always loaded as "td3",
        # even when the name also contains "m2td3".
        is_tc = "tc-" in algo or "tc_" in algo
        agent_type = "m2td3" if ("m2td3" in algo and not is_tc) else "td3"

        algo_folder = os.path.join(all_agents_folder, algo)
        for env_name in tqdm(env_names, desc="Environment", leave=False):
            agents_folder = os.path.join(algo_folder, env_name, str(nb_uncertainty_dim))
            # Sorted lexicographically, so the seed index follows file-name order.
            agent_files = sorted(os.listdir(agents_folder))
            # Wrapping the list (not the enumerate object) lets tqdm show a total.
            for i, agent_file in enumerate(
                tqdm(agent_files, desc="Seed", leave=False)
            ):
                agent_path = os.path.join(agents_folder, agent_file)
                eval_env = env_factory(env_name=env_name)

                # NOTE: the keyword is looked up in the full path, as before, so it
                # may come from the algorithm folder name as well as the file name.
                wrappers = [
                    wrapper for keyword, wrapper in env_wrappers if keyword in agent_path
                ]
                if wrappers:
                    params_bound: Dict[str, List[float]] = bound_factory(
                        env_name=env_name, nb_dim=nb_uncertainty_dim
                    )
                    for wrapper in wrappers:
                        eval_env = wrapper(eval_env, params_bound)

                agent = get_agent(
                    agent_path, agent_type, eval_env, nb_uncertainty_dim, device
                )
                results = evaluate_agent(
                    agent, eval_env, env_name, nb_uncertainty_dim, nb_episodes, verbose
                )

                results["algo"] = algo
                results["env_name"] = env_name
                results["uncertainty_dim"] = nb_uncertainty_dim
                results["seed"] = i
                per_agent_results.append(results)

    if per_agent_results:
        all_results = pd.concat(per_agent_results, ignore_index=True).reindex(
            columns=result_columns
        )
    else:
        # No agent found: still write a CSV with the expected header.
        all_results = pd.DataFrame(columns=result_columns)

    os.makedirs(output_folder, exist_ok=True)
    all_results.to_csv(os.path.join(output_folder, "results.csv"), index=False)


def mean_and_std(rewards: List[float]) -> Tuple[float, float]:
    """
    Return the arithmetic mean and the population standard deviation of ``rewards``.

    The deviation is normalised by ``len(rewards)`` (not ``len(rewards) - 1``).
    An empty list raises ``ZeroDivisionError``.
    """
    average = sum(rewards) / len(rewards)
    squared_deviations = sum((value - average) ** 2 for value in rewards)
    variance = squared_deviations / len(rewards)
    return average, variance ** 0.5


def get_agent(agent_path, agent_type, env, nb_uncertainty_dim, device):
    """
    Build an inference agent from a saved actor.

    Args:
        agent_path: Path of the actor to load.
        agent_type: Agent type name passed to the builder.
        env: Environment passed to the builder.
        nb_uncertainty_dim: Number of uncertainty dimensions.
        device: Device passed to the builder (both at construction and via ``add_device``).
    """
    builder = AgentInferenceBuilder(env=env, nb_dim=nb_uncertainty_dim, device=device)

    # Configure the builder step by step, in the same order as the fluent chain.
    with_actor = builder.add_actor_path(path=agent_path)
    with_device = with_actor.add_device(device)
    with_type = with_device.add_agent_type(agent_type)

    built_agent: Agent = with_type.build()
    return built_agent


def _result_row(scheduler_type, rewards, verbose):
    """Summarise one evaluation run as a result row, printing the rewards if verbose."""
    mean_reward, std_reward = mean_and_std(rewards)
    if verbose:
        print(f"{scheduler_type} rewards: {rewards}")
    return {
        "scheduler_type": scheduler_type,
        "rewards_mean": mean_reward,
        "rewards_std": std_reward,
        "rewards": rewards,
    }


def evaluate_agent(
    agent,
    eval_env,
    env_name,
    nb_uncertainty_dim,
    nb_episodes,
    verbose,
):
    """
    Given an agent, evaluate it in an environment with different scheduler types and return the results as a DataFrame.

    The agent is first evaluated without any scheduler (reported as "vanilla"),
    then once per scheduler type, in this order: "linear", "exponential",
    "logarithmic", "random", "cosine".

    Args:
        agent (Agent): The agent to evaluate.
        eval_env: The environment the agent is evaluated in.
        env_name (str): Name of the environment.
        nb_uncertainty_dim (int): Number of uncertainty dimensions.
        nb_episodes (int): Number of episodes to run for each scheduler type.
        verbose (bool): Whether to print verbose output.

    Returns:
        pd.DataFrame: One row per evaluation with the columns "scheduler_type",
        "rewards_mean", "rewards_std" and "rewards".
    """
    params_bound: Dict[str, List[float]] = bound_factory(
        env_name=env_name, nb_dim=nb_uncertainty_dim
    )

    # Scheduler types with the extra keyword arguments each one is built with.
    scheduler_configs = [
        ("linear", {"max_step": 1000}),
        ("exponential", {"max_step": 1000}),
        ("logarithmic", {"max_step": 1000}),
        ("random", {"radius": 0.1}),
        ("cosine", {"radius": 0.0015}),
    ]

    # Baseline run without any scheduler.
    rewards = evaluate(
        env=eval_env, agent=agent, seed=EVAL_SEED, num_episodes=nb_episodes
    )
    rows = [_result_row("vanilla", rewards, verbose)]

    for scheduler_type, extra_kwargs in scheduler_configs:
        scheduler = scheduler_factory(
            scheduler_type=scheduler_type,
            params_bound=params_bound,
            seed=EVAL_SEED,
            **extra_kwargs,
        )
        rewards = evaluate_with_scheduler(
            env=eval_env,
            agent=agent,
            scheduler=scheduler,
            seed=EVAL_SEED,
            num_episodes=nb_episodes,
        )
        rows.append(_result_row(scheduler_type, rewards, verbose))

    # Build the frame once rather than concatenating a one-row frame per run.
    return pd.DataFrame(
        rows, columns=["scheduler_type", "rewards_mean", "rewards_std", "rewards"]
    )


# Script entry point: python-fire turns main's parameters into command-line arguments.
if __name__ == "__main__":
    Fire(main)
