from math import e
import os
import uuid
import socket
import gymnasium as gym
import rrls
from td3.trainer import TrainerM2TD3
from utils import env_factory, bound_factory
from tc_mdp import EvalStackedTCMDP, StackedTCMDP
from fire import Fire
from dotenv import load_dotenv
from scheduler_benchmark import evaluate_agent

# Load variables from a `.env` file (if one is found) into the process
# environment at import time, before `main` runs.
# NOTE(review): presumably this supplies experiment-tracking credentials — confirm.
load_dotenv()


def main(
    env_name: str = "Walker",
    project_name: str = "dev_stacked_tc_m2td3_001",
    nb_uncertainty_dim: int = 3,
    max_steps: int = 50_000,
    start_steps: int = 25_000,
    seed: int = 0,
    eval_freq: int = 10_000,
    track: bool = True,
    radius: float = 0.001,
    omniscient_adversary: bool = True,
    device: str = "cuda:0",
    **kwargs,
) -> None:
    """Train an M2TD3 agent on a stacked TC-MDP environment and evaluate it.

    The run writes into
    ``refactor_results/<project_name>/<env_name>_nb_uncertainty_dim_<n>_radius_<r>/<uuid>/``,
    where ``<r>`` is ``radius`` with ``.`` replaced by ``_`` and ``<uuid>`` is
    a fresh UUID4 generated for this invocation. The evaluation table is
    saved there as ``results.csv``.

    Args:
        env_name: Environment key passed to ``env_factory``, ``bound_factory``
            and ``evaluate_agent``.
        project_name: Used in the output path and forwarded to
            ``trainer.train`` as ``project_name``.
        nb_uncertainty_dim: Passed to ``bound_factory`` and ``evaluate_agent``.
        max_steps: Forwarded to ``trainer.train``.
        start_steps: Forwarded to ``trainer.train``.
        seed: Forwarded to ``trainer.train`` and recorded in the run params.
        eval_freq: Forwarded to ``trainer.train``.
        track: Forwarded to ``trainer.train``.
        radius: Forwarded to ``StackedTCMDP`` and encoded in the output path
            and experiment name.
        omniscient_adversary: Forwarded to ``StackedTCMDP``.
        device: Forwarded to ``TrainerM2TD3``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``TrainerM2TD3``.
    """
    radius_str = str(radius).replace(".", "_")
    output_dir = (
        f"refactor_results/{project_name}/"
        f"{env_name}_nb_uncertainty_dim_{nb_uncertainty_dim}_radius_{radius_str}"
    )
    unique_id = str(uuid.uuid4())
    # Per-run sub-directory; the UUID keeps repeated runs with the same
    # configuration from overwriting each other's results.
    run_dir = f"{output_dir}/{unique_id}"
    os.makedirs(run_dir, exist_ok=True)

    env = env_factory(env_name)
    eval_env = env_factory(env_name)
    params_bound = bound_factory(env_name, nb_uncertainty_dim)
    env = StackedTCMDP(
        env,
        params_bound=params_bound,
        radius=radius,
        omniscient_adversary=omniscient_adversary,
    )
    eval_env = EvalStackedTCMDP(
        eval_env,
        params_bound=params_bound,
    )
    experiment_name = f"{env_name}_{radius_str}_{unique_id}"

    # Run metadata handed to the trainer alongside the environments.
    params = {
        "env_name": env_name,
        "nb_uncertainty_dim": nb_uncertainty_dim,
        "radius": radius,
        "seed": seed,
        "omniscient_adversary": omniscient_adversary,
        "machine_name": socket.gethostname(),
    }
    # NOTE(review): the trainer receives the configuration-level directory,
    # not the per-run `run_dir` — confirm this is intended.
    trainer = TrainerM2TD3(
        env=env,
        eval_env=eval_env,
        device=device,
        params=params,
        save_dir=output_dir,
        **kwargs,
    )
    trainer.train(
        experiment_name=experiment_name,
        max_steps=max_steps,
        start_steps=start_steps,
        project_name=project_name,
        seed=seed,
        eval_freq=eval_freq,
        track=track,
    )

    res_df = evaluate_agent(
        agent=trainer.agent,
        env_name=env_name,
        nb_uncertainty_dim=nb_uncertainty_dim,
        eval_env=eval_env,
        nb_episodes=10,
        verbose=True,
    )

    # Save results
    res_df.to_csv(f"{run_dir}/results.csv", index=False)

if __name__ == "__main__":
    # python-fire exposes main's keyword arguments as command-line flags,
    # e.g. `--env_name=Walker --radius=0.001`.
    Fire(main)
