from typing import Unpack
from m2td3.trainer import Trainer
from m2td3.algo import M2TD3Config
from fire import Fire
from dotenv import load_dotenv

# Populate os.environ from a `.env` file (python-dotenv); runs at import time.
load_dotenv()


def main(
    experiment_name: str,
    env_name: str,
    nb_uncertainty_dim: int,
    seed: int,
    output_dir: str,
    start_steps: int = 100_000,
    evaluate_qvalue_interval: int = 10_000,
    logger_interval: int = 100_000,
    evaluate_interval: int = 100_000,
    max_steps: int = 2_000_000,
    device: str = "cuda:0",
    oracle_parameters_agent: bool = False,
    **kwargs: Unpack[M2TD3Config],
) -> None:
    """Build a ``Trainer`` from the given arguments and run it.

    This function is exposed as a command-line interface through
    ``fire.Fire``, so every parameter can be supplied as ``--name=value``.
    Step-count arguments given in float notation on the command line
    (e.g. ``--max_steps=2e6``) are converted to ``int`` before being
    forwarded.

    Parameters
    ----------
    experiment_name : str
        Experiment name, forwarded to ``Trainer``.
    env_name : str
        Environment name, forwarded to ``Trainer``.
    nb_uncertainty_dim : int
        Forwarded to ``Trainer``; presumably the number of uncertainty
        parameter dimensions (TODO confirm in ``Trainer``).
    seed : int
        Random seed, forwarded to ``Trainer``.
    output_dir : str
        Output directory, forwarded to ``Trainer``.
    start_steps : int, default 100_000
        Forwarded to ``Trainer``; presumably the number of initial steps
        before learning starts (TODO confirm).
    evaluate_qvalue_interval : int, default 10_000
        Forwarded to ``Trainer``; presumably the step interval between
        Q-value evaluations.
    logger_interval : int, default 100_000
        Forwarded to ``Trainer``; presumably the step interval between logs.
    evaluate_interval : int, default 100_000
        Forwarded to ``Trainer``; presumably the step interval between
        policy evaluations.
    max_steps : int, default 2_000_000
        Forwarded to ``Trainer``; presumably the total number of training
        steps.
    device : str, default "cuda:0"
        Device string forwarded to ``Trainer``.
    oracle_parameters_agent : bool, default False
        Flag forwarded to ``Trainer``; presumably lets the agent observe the
        true uncertainty parameters (TODO confirm).
    **kwargs
        Additional options typed by ``M2TD3Config``, forwarded unchanged to
        ``Trainer``.

    Raises
    ------
    ValueError
        If a step-count argument cannot be converted to ``int``.
    """
    # Fire parses CLI values such as ``2e6`` into floats; normalize the step
    # counts so Trainer always receives the annotated ``int`` type.
    trainer = Trainer(
        experiment_name=experiment_name,
        env_name=env_name,
        nb_uncertainty_dim=nb_uncertainty_dim,
        device=device,
        seed=seed,
        output_dir=output_dir,
        start_steps=int(start_steps),
        evaluate_qvalue_interval=int(evaluate_qvalue_interval),
        logger_interval=int(logger_interval),
        evaluate_interval=int(evaluate_interval),
        max_steps=int(max_steps),
        oracle_parameters_agent=oracle_parameters_agent,
        **kwargs,
    )
    trainer.main()


if __name__ == "__main__":
    # Turn `main`'s signature into a command-line interface via python-fire.
    Fire(component=main)
