import os
import re
import fire
import wandb
from retry import retry
from dotenv import load_dotenv
import logging  # Import logging module

# Root logger: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Load variables from a .env file (if one is found) into the process
# environment; presumably this supplies wandb credentials — confirm.
load_dotenv()


def extract_command(command_line: str) -> dict:
    """
    Extracts run settings from a reconstructed command line string.

    Each setting is looked up first in its ``--flag=value`` spelling and, only
    if that is absent, in its ``--flag value`` spelling. The first occurrence
    of the matching spelling wins. All values are returned as strings.

    Recognised settings (dictionary key <- command-line flag):
        env_name           <- --env_name
        seed               <- --seed               (digits only)
        project_name       <- --project_name
        nb_uncertainty_dim <- --nb_uncertainty_dim (digits only)
        agent_name         <- --agent_path         (reduced to its
                                                    second-to-last "/" component)

    Args:
        command_line (str): The command line string to extract values from.

    Returns:
        dict: Maps each setting found to its string value. Settings absent
        from ``command_line`` are omitted from the dictionary, so callers must
        supply their own defaults.
    """
    # key -> (pattern for "--flag=value", fallback pattern for "--flag value")
    patterns: dict[str, tuple[str, str]] = {
        "env_name": (r"--env_name=([^ ]+)", r"--env_name\s+(\S+)"),
        "seed": (r"--seed=(\d+)", r"--seed\s+(\d+)"),
        "project_name": (r"--project_name=([^ ]+)", r"--project_name\s+(\S+)"),
        # Digits only in both spellings, so in "--nb_uncertainty_dim --seed 1"
        # the following flag is not captured as the value.
        "nb_uncertainty_dim": (
            r"--nb_uncertainty_dim=(\d+)",
            r"--nb_uncertainty_dim\s+(\d+)",
        ),
        "agent_name": (r"--agent_path=([^ ]+)", r"--agent_path\s+([\w/._-]+)"),
    }

    extracted_values: dict[str, str] = {}
    for key, (equals_pattern, spaced_pattern) in patterns.items():
        # The "=" spelling takes precedence; fall back to the spaced one.
        match = re.search(equals_pattern, command_line) or re.search(
            spaced_pattern, command_line
        )
        if match:
            extracted_values[key] = match.group(1)
            logging.debug("Parsed %s=%s from command line", key, match.group(1))

    agent_path = extracted_values.get("agent_name")
    if agent_path is not None:
        # Keep the second-to-last "/" component, e.g. "runs/agent_x/model.pt"
        # and "runs/agent_x/" both give "agent_x". A path without "/" has no
        # such component, so it is kept whole instead of raising IndexError.
        parts = agent_path.split("/")
        extracted_values["agent_name"] = parts[-2] if len(parts) > 1 else agent_path

    return extracted_values


@retry(tries=10, delay=2)
def get_wandb_run_command(entity: str, project: str, run_id: str) -> str:
    """
    Reconstructs the command line that launched a wandb run from its metadata.

    The command is rebuilt as ``"<program> <arg1> <arg2> ..."`` from the run
    metadata's ``program`` and ``args`` entries. Because of the ``@retry``
    decorator, any exception raised here is attempted up to 10 times with a
    2-second delay before it propagates.

    Args:
        entity (str): The entity (user or organization) under which the project is hosted.
        project (str): The name of the project in wandb.
        run_id (str): The unique identifier for the run.

    Returns:
        str: The reconstructed command. ``program`` and ``args`` fall back to
        empty values when the metadata lacks them.

    Raises:
        ValueError: If the run has no metadata at all.
    """
    api = wandb.Api()
    run = api.run(path=f"{entity}/{project}/{run_id}")

    metadata = run.metadata
    if metadata is None:
        raise ValueError(f"Run {run_id} has no metadata")

    # Arguments are joined with single spaces, so an argument that itself
    # contains whitespace cannot be distinguished from two arguments.
    args = " ".join(metadata.get("args") or [])
    # A missing "program" previously raised KeyError, which @retry would
    # pointlessly retry; use an empty program name instead.
    program = metadata.get("program") or ""
    return f"{program} {args}"


def list_wandb_run_ids(entity: str, project: str) -> list[str]:
    """
    Returns the ID of every run in a wandb project.

    Args:
        entity (str): The user or organization that owns the project.
        project (str): The wandb project name.

    Returns:
        list[str]: Run IDs, in the order the wandb API yields the runs.
    """
    project_path = f"{entity}/{project}"
    return [r.id for r in wandb.Api().runs(path=project_path)]


def download_artifacts_if_run_successful(
    entity: str, project: str, run_id: str, artifact_type: str, download_dir: str
):
    """
    Downloads a run's logged artifacts of one type, but only if the run finished.

    Runs whose state is anything other than ``"finished"`` are skipped with an
    info log. ``download_dir`` is created on demand, and only when at least one
    matching artifact exists.

    Args:
        entity (str): The entity (user or organization) under which the project is hosted.
        project (str): The name of the project in wandb.
        run_id (str): The unique identifier for the run.
        artifact_type (str): The type of artifact to download (e.g., "model", "dataset").
        download_dir (str): The directory where artifacts should be downloaded.

    Returns:
        None
    """
    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")

    if run.state != "finished":
        logging.info("Run %s was not successful. State: %s", run_id, run.state)
        return

    downloaded = 0
    for artifact in run.logged_artifacts():
        if artifact.type != artifact_type:
            continue
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(download_dir, exist_ok=True)
        logging.info("Downloading artifact: %s", artifact.name)
        artifact.download(root=download_dir)
        logging.info("Downloaded to %s", download_dir)
        downloaded += 1

    if downloaded == 0:
        logging.info("Run %s logged no '%s' artifacts", run_id, artifact_type)


def main(entity: str, project: str, download_folder: str):
    """
    Main function to execute script functionalities through CLI.

    For every run in the project, rebuilds its launch command, parses the run
    settings from it, and downloads the run's "model" artifacts into
    ``<download_folder>/<project_name>_<env_name>_<nb_uncertainty_dim>_<seed>_<agent_name>``.
    If that folder already exists, "_prime" is appended repeatedly until the
    path is unused, so runs with identical settings never share a folder.

    Args:
        entity (str): The entity under which the project is hosted.
        project (str): The name of the project.
        download_folder (str): The directory where artifacts should be downloaded.
    """
    run_ids = list_wandb_run_ids(entity=entity, project=project)
    for run_id in run_ids:
        command_line = get_wandb_run_command(
            entity=entity, project=project, run_id=run_id
        )
        parsed_command = extract_command(command_line=command_line)

        # Numeric settings default silently; naming settings default to a
        # placeholder with a warning because they make folders ambiguous.
        parsed_command.setdefault("nb_uncertainty_dim", "0")
        parsed_command.setdefault("seed", "0")

        if "project_name" not in parsed_command:
            logging.warning(
                "Project name not found in command line, using Default, it may be hard to identify the project."
            )
            parsed_command["project_name"] = "Default"
        if "env_name" not in parsed_command:
            logging.warning(
                "Env name not found in command line, using DefaultEnv, it may be hard to identify the env."
            )
            parsed_command["env_name"] = "DefaultEnv"
        if "agent_name" not in parsed_command:
            logging.warning(
                "Agent path not found in command line, using DefaultAgent, it may be hard to identify the agent."
            )
            parsed_command["agent_name"] = "DefaultAgent"

        folder_name = "_".join(
            parsed_command[key]
            for key in (
                "project_name",
                "env_name",
                "nb_uncertainty_dim",
                "seed",
                "agent_name",
            )
        )
        sub_folder_download = os.path.join(download_folder, folder_name)

        # Keep appending "_prime" until the path is free; a single suffix would
        # let a third run with the same settings reuse the "_prime" folder.
        while os.path.exists(sub_folder_download):
            sub_folder_download += "_prime"

        download_artifacts_if_run_successful(
            entity=entity,
            project=project,
            run_id=run_id,
            artifact_type="model",
            download_dir=sub_folder_download,
        )


if __name__ == "__main__":
    # Expose `main` as a command-line interface; fire maps CLI arguments
    # onto its parameters (entity, project, download_folder).
    fire.Fire(component=main)
