import json
import os
import subprocess
from datetime import datetime

from loguru import logger


def get_git_commit_hash():
    """Return the commit hash of Git ``HEAD``, or ``"unknown"`` on failure.

    Runs ``git rev-parse HEAD`` in the process's current working directory.
    The hash is only used as provenance metadata, so failure is non-fatal:
    the problem is logged and ``"unknown"`` is returned instead of raising.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            # Capture git's stderr rather than letting "fatal: ..." leak to the
            # console; it is folded into the log message below instead.
            stderr=subprocess.PIPE,
        )
        return output.decode("ascii").strip()
    except subprocess.CalledProcessError as e:
        # git ran but failed (e.g. not inside a repository, or no commits yet).
        detail = e.stderr.decode(errors="replace").strip() if e.stderr else ""
        logger.error(f"Error obtaining Git commit hash: {e} {detail}".rstrip())
        return "unknown"
    except (OSError, UnicodeDecodeError) as e:
        # OSError: git executable missing / not runnable.
        logger.error(f"Error obtaining Git commit hash: {e}")
        return "unknown"


def create_directory(path):
    """Make sure ``path`` exists as a directory (parents included) and return it.

    Already-existing directories are accepted silently.

    Raises:
        OSError: re-raised after being logged when the directory cannot be
            created.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        logger.error(f"Error creating directory {path}: {err}")
        raise
    else:
        return path


def create_experiment_subdirectories(main_dir):
    """Create the ``models`` and ``outputs`` folders inside ``main_dir``.

    Returns:
        dict: maps each subdirectory name (``"models"``, ``"outputs"``) to the
        path returned by ``create_directory``.
    """
    return {
        name: create_directory(os.path.join(main_dir, name))
        for name in ("models", "outputs")
    }


def _train_on_test_tags(cfg):
    """Return the name tags describing the train-on-test (``tm-TTT``) settings."""
    tags = ["tm-TTT", f"numItertot-{cfg.num_iter_train_on_test}"]
    if cfg.use_batch_train_on_test:
        tags.extend(("useBatchTot", f"bstot-{cfg.test_batch_size}"))
    if cfg.duplicate_example_train_on_test:
        tags.append("dupExTot")
    if cfg.perturb_model_train_on_test:
        tags.append("perturbTot")
    tags.append(f"initTot-{cfg.num_init_train_on_test}")
    return tags


def _dual_network_tags(cfg):
    """Return the name tags describing the dual-network settings."""
    tags = ["dualNetwork"]
    if cfg.use_pgm_optimal_dn:
        tags.append("pgmOptimal")
    if cfg.copy_student_to_teacher_dn:
        tags.append("copyStT")
    tags.extend(
        (
            f"loss-{cfg.loss_dn}",
            f"totDN-{cfg.tot_train_dn}",
            f"lrDN-{cfg.teacher_lr}",
            f"optDN-{cfg.teacher_optimizer}",
        )
    )
    return tags


def _hyperparameter_tags(cfg):
    """Return every hyper-parameter tag used to name a regular run."""
    tags = [
        f"lr-{cfg.train_lr}",
        f"ep-{cfg.epochs}",
        f"bs-{cfg.batch_size}",
        f"act-{cfg.activation_function}",
        f"opt-{cfg.train_optimizer}-{cfg.test_optimizer}",
        f"lrsched-{cfg.lr_scheduler}",
        f"ip-{cfg.input_type}",
        f"emb-{cfg.embedding_type}",
    ]
    if cfg.only_test_train_on_test_set:
        tags.append("onlyTestTrue")
    if cfg.train_on_test_set:
        tags += _train_on_test_tags(cfg)
    if cfg.dual_network:
        tags += _dual_network_tags(cfg)
    if cfg.add_entropy_loss:
        tags.append(f"ent-{cfg.entropy_lambda}")

    # Options whose own (truthy) value is embedded in the tag, e.g. "ste-True".
    valued_flags = (
        ("ste", cfg.use_ste),
        ("supLoss", cfg.add_supervised_loss),
        ("distLossLL", cfg.add_distance_loss_evid_ll),
        ("sameBucket", cfg.same_bucket_iter),
    )
    tags.extend(f"{prefix}-{value}" for prefix, value in valued_flags if value)

    # Only the continuous embeddings have a size worth recording.
    if cfg.embedding_type == "continuousEmbed":
        tags.append(f"embdim-{cfg.embedding_size}")
    elif cfg.embedding_type == "continuousSin":
        tags.append(f"embdim-{cfg.sin_embedding_size}")

    if "spn" in cfg.input_type:
        tags.append(f"per-{cfg.percent_nodes_for_features}")
    return tags


def prepare_experiment_config(cfg):
    """Derive the naming information for an experiment's output directory.

    Args:
        cfg: attribute-style configuration object (read via ``cfg.<name>``).

    Returns:
        dict with keys, in this order:
            ``base_dir``          -- ``"debug"`` when ``cfg.debug`` is truthy,
                                     else ``cfg.experiment_dir``;
            ``task_model``        -- ``Task-..._PGM-..._Model-...`` when both
                                     ``cfg.task`` and ``cfg.model`` are set;
            ``extra``             -- always ``None`` here;
            ``debug_tuning``      -- tuning tag when ``cfg.debug_tuning``;
            ``use_saved_buckets`` -- bucket tag when ``cfg.use_saved_buckets``;
            ``training_mode``     -- always ``None`` here;
            ``params``            -- ``_``-joined hyper-parameter tags, only
                                     when neither of the two modes above is on;
            ``dir_name``          -- ``_``-joined non-empty values of
                                     ``task_model``, ``debug_tuning``,
                                     ``use_saved_buckets`` and ``params``.
        Unset entries are ``None``.
    """
    config = dict.fromkeys(
        (
            "base_dir",
            "task_model",
            "extra",
            "debug_tuning",
            "use_saved_buckets",
            "training_mode",
            "params",
        )
    )
    config["base_dir"] = cfg.experiment_dir if not cfg.debug else "debug"

    if cfg.task and cfg.model:
        config["task_model"] = "_".join(
            (f"Task-{cfg.task}", f"PGM-{cfg.pgm}", f"Model-{cfg.model}")
        )

    if cfg.debug_tuning:
        mode = "trainontest" if cfg.train_on_test_set else "standardtuning"
        config["debug_tuning"] = (
            f"{cfg.model_type}_{cfg.student_layers}_"
            f"{cfg.lr_scheduler}_{cfg.train_optimizer}_{mode}"
        )

    if cfg.use_saved_buckets:
        penalty = "withpenalty" if cfg.add_entropy_loss else "nopenalty"
        config["use_saved_buckets"] = "savedbuckets" + penalty

    # The full hyper-parameter tag list is only used for regular runs.
    if not (cfg.debug_tuning or cfg.use_saved_buckets):
        config["params"] = "_".join(
            tag for tag in _hyperparameter_tags(cfg) if tag
        )

    name_parts = (
        config[key]
        for key in ("task_model", "debug_tuning", "use_saved_buckets", "params")
    )
    config["dir_name"] = "_".join(part for part in name_parts if part)
    return config


def save_experiment_metadata(output_dir, cfg):
    """Write ``cfg``'s attributes plus provenance info to ``output_dir/metadata.json``.

    The JSON object (indented by 4 spaces) holds every attribute of ``cfg``
    as returned by ``vars(cfg)``, plus:
        ``git_commit_hash``       -- from ``get_git_commit_hash()``;
        ``experiment_start_time`` -- ``datetime.now().isoformat()`` (local,
                                     naive time).

    ``cfg`` itself is left unmodified.

    Raises:
        TypeError: if an attribute of ``cfg`` is not JSON-serializable; in
            that case no file is created or overwritten.
        OSError: if the file cannot be written.
    """
    # Copy: vars() returns the object's live __dict__, so writing into it
    # directly would add these keys as attributes on the caller's cfg.
    metadata = dict(vars(cfg))
    metadata["git_commit_hash"] = get_git_commit_hash()
    metadata["experiment_start_time"] = datetime.now().isoformat()
    # Add system and environment information if needed

    # Serialize before opening the file so a non-serializable value cannot
    # leave a truncated metadata.json behind.
    payload = json.dumps(metadata, indent=4)
    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, "w", encoding="utf-8") as f:
        f.write(payload)


def get_output_dir(cfg, project_name):
    """Create the on-disk layout for one experiment run and return its paths.

    The main directory is ``<base_dir>/<dir_name>/<project_name>``, with both
    components taken from ``prepare_experiment_config(cfg)``. It receives
    ``models``/``outputs`` subdirectories and a ``metadata.json`` describing
    ``cfg``.

    Returns:
        tuple: ``(main_dir_path, subdirectories)``, where ``subdirectories`` is
        the mapping returned by ``create_experiment_subdirectories``.
    """
    naming = prepare_experiment_config(cfg)
    experiment_root = create_directory(
        os.path.join(naming["base_dir"], naming["dir_name"], project_name)
    )
    subdirs = create_experiment_subdirectories(experiment_root)
    save_experiment_metadata(experiment_root, cfg)
    return experiment_root, subdirs
