import argparse
import os
import re
import time
from pprint import pprint

from data.dataset_names import dataset_names
from loguru import logger


def define_arguments():
    """Build the experiment CLI, parse ``sys.argv`` and post-process the result.

    Returns:
        Whatever ``process_arguments`` returns for the parsed namespace: a
        ``(cfg, project_name)`` tuple.
    """
    parser = argparse.ArgumentParser(description="SS-CMPE Experiments")

    # Debugging and Environment Settings
    # NOTE: ``store_true`` combined with ``default=True`` means the value is
    # True whether or not the flag is given on the command line.
    parser.add_argument(  # set this as True to run experiments
        "--no-debug", action="store_true", default=True, help="Disable debug mode."
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="Disable CUDA training."
    )
    parser.add_argument(
        "--no-mps",
        action="store_true",
        default=False,
        help="Disable macOS GPU training.",
    )
    parser.add_argument(
        "--data_device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help='Select the device to use: "cpu" or "cuda".',
    )

    parser.add_argument(
        "--seed", type=int, default=27, help="Random seed (default: 27)."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="Quickly check a single pass.",
    )
    parser.add_argument(
        "--pgm",
        choices=["bn", "mn", "spn", "made"],
        default="mn",
        help="PGM to use (default: mn).",
    )
    parser.add_argument(
        "--i-bound",
        type=str,
        choices=["1", "2", "3", "4", "5"],
        default="5",
        help="I-bound to use while loading MN outputs(default: 5).",
    )
    # Model and Training Settings
    parser.add_argument(
        "--model",
        type=str,
        choices=[
            "nn",
            "transformer",
        ],
        default="nn",
        help="Model to use (default: nn).",
    )
    parser.add_argument(
        "--model-type",
        default="2",
        choices=["1", "2", "3", "4"],
        help="Model type to train (default: 2).",
    )
    parser.add_argument(
        "--latent-dim",
        type=int,
        help="Latent size for AE.",
    )
    parser.add_argument(
        "--use-ste",
        action="store_true",
        default=False,
        help="Use straight through estimator.",
    )
    parser.add_argument(
        "--use-single-model",
        action="store_true",
        default=False,
        help="Use one model per dataset.",
    )
    parser.add_argument(
        "--use-sampled-buckets",
        action="store_true",
        default=False,
        help="Use sampled buckets.",
    )
    parser.add_argument(
        "--num-sampled-buckets",
        type=int,
        default=10,
        help="Number of sampled buckets (default: 10).",
    )
    parser.add_argument(
        "--input-type",
        default="data",
        choices=[
            "data",
            "spn",
            "dataSpn",
        ],
        help="Data to use while training the model. We support using features extracted from SPN as well.",
    )
    # input can be discrete or continuous
    parser.add_argument(
        "--embedding-type",
        default="discrete",
        choices=["discrete", "continuousConst", "continuousEmbed", "continuousSin"],
        help="Embedding type (default: discrete). This is how we embed the three types of inputs (query, evidence, unobs (for MMAP)).",
    )
    parser.add_argument(
        "--sin-embedding-size",
        type=int,
        help="Size of sin embedding.",
    )
    parser.add_argument(
        "--amplitude",
        type=float,
        default=1.0,
        help="Amplitude for sin embedding (default: 1.0).",
    )
    parser.add_argument(
        "--phase",
        type=float,
        default=0.0,
        help="Phase for sin embedding (default: 0.0).",
    )
    parser.add_argument(
        "--embedding-size",
        type=int,
        default=4,
        help="Embedding size (default: 4).",
    )
    parser.add_argument(
        "--num-states",
        type=int,
        default=4,
        help="Number of states for discrete variables (default: 4).",
    )
    parser.add_argument(
        "--evaluate-training-set",
        action="store_true",
        default=False,
        help="Evaluate on training set.",
    )
    parser.add_argument(
        "--percent-nodes-for-features",
        type=float,
        default=0.5,
        help="Percent of nodes to use for SPN features (default: 0.5).",
    )
    parser.add_argument(
        "--teacher-layers",
        type=int,
        default=3,
        help="Number of layers for teacher model (default: 3).",
    )
    parser.add_argument(
        "--student-layers",
        type=int,
        default=3,
        help="Number of layers for student model (default: 3).",
    )
    parser.add_argument(
        "--no-train",
        action="store_true",
        default=False,
        help="Disable training, only testing will be done.",
    )
    parser.add_argument(
        "--no-test",
        action="store_true",
        default=False,
        help="Disable testing, only training will be done.",
    )
    parser.add_argument(
        "--no-dropout",
        action="store_true",
        default=False,
        help="Disable dropout in the model.",
    )
    # NOTE: ``store_true`` with ``default=True`` -- always True, the flag is a no-op.
    parser.add_argument(
        "--no-batchnorm",
        action="store_true",
        default=True,
        help="Disable batch normalization in the model.",
    )
    parser.add_argument(
        "--dropout-rate", type=float, default=0.2, help="Dropout rate (default: 0.2)."
    )
    parser.add_argument(
        "--train-weight-decay",
        type=float,
        default=0,
        help="Weight decay for SSL (default: 0).",
    )
    parser.add_argument(
        "--test-weight-decay",
        type=float,
        default=0,
        help="Weight decay for SSL (default: 0).",
    )
    parser.add_argument(
        "--teacher-weight-decay",
        type=float,
        default=0,
        help="Weight decay for SSL of teacher (default: 0).",
    )

    parser.add_argument(
        "--activation-function",
        default="sigmoid",
        choices=["sigmoid", "hard_sigmoid"],
        help="Activation function (default: sigmoid).",
    )
    parser.add_argument(
        "--hidden-activation-function",
        default="relu",
        choices=[
            "relu",
            "leaky_relu",
        ],
        help="Activation function (default: relu).",
    )

    # Batch and Epoch Settings
    parser.add_argument(
        "--batch-size",
        type=int,
        default=512,
        help="Input batch size for training (default: 512).",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=2048,
        help="Input batch size for testing (default: 2048).",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=20,
        help="Number of epochs to train (default: 20).",
    )
    # Learning Rate and Scheduler
    parser.add_argument(
        "--train-lr", type=float, default=1e-3, help="Learning rate (default: 0.001)."
    )
    parser.add_argument(
        "--test-lr", type=float, default=1e-3, help="Learning rate (default: 0.001)."
    )
    parser.add_argument(
        "--teacher-lr",
        type=float,
        default=1e-3,
        help="Learning rate for the teacher model (default: 0.001).",
    )
    parser.add_argument(
        "--replace",
        action="store_true",
        default=False,
        help="Replace the model and outputs. If False then we skip the experiment if the run exists.",
    )
    parser.add_argument(
        "--add-gradient-clipping",
        action="store_true",
        default=False,
        help="Add gradient clipping.",
    )
    parser.add_argument(
        "--grad-clip-norm",
        type=float,
        default=5.0,
        help="Gradient clipping norm (default: 5.0).",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.90,
        help="Learning rate step gamma (default: 0.90).",
    )
    parser.add_argument(
        "--lr-scheduler",
        default="ReduceLROnPlateau",
        choices=[
            "StepLR",
            "ReduceLROnPlateau",
            "OneCycleLR",
            "MultiStepLR",
            "ExponentialLR",
            "CosineAnnealingLR",
            "CyclicLR",
            "CosineAnnealingWarmRestarts",
            "None",
        ],
        help="LR Scheduler (default: ReduceLROnPlateau).",
    )
    parser.add_argument(
        "--train-optimizer", default="adam", help="Choose the optimizer"
    )
    parser.add_argument("--test-optimizer", default="adam", help="Choose the optimizer")

    parser.add_argument(
        "--teacher-optimizer", default="adam", help="Choose the optimizer"
    )

    # Dataset and Paths
    parser.add_argument(
        "--dataset", default="nltcs", help="Dataset to use (default: nltcs)."
    )
    parser.add_argument(
        "--pgm-model-directory",
        type=str,
        help="Location of the SPN model.",
    )
    parser.add_argument(
        "--dataset-directory",
        type=str,
    )
    parser.add_argument(
        "--experiment-dir",
        type=str,
        help="Location of the saved models and outputs.",
    )
    parser.add_argument(
        "--nn-model-path",
        type=str,
        help="Location to load the trained NN model for testing.",
    )
    parser.add_argument(
        "--use-saved-buckets",
        action="store_true",
        default=False,
        help="Use saved buckets.",
    )
    parser.add_argument(
        "--saved-buckets-directory",
        type=str,
        default="",
        help="Directory for saved buckets.",
    )
    parser.add_argument(
        "--same-bucket-iter",
        action="store_true",
        default=False,
        help="Use the same bucket for an epoch.",
    )
    # Miscellaneous
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2,
        help="Batches to wait before logging training status (default: 2).",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.50,
        help="Threshold for binary output conversion (default: 0.50).",
    )
    parser.add_argument(
        "--query-prob",
        type=float,
        default=0.70,
        help="Probability of query variables (default: 0.70).",
    )
    parser.add_argument(
        "--no-log-loss",
        action="store_true",
        default=False,
        help="Disable log in loss evaluation. Using log-likelihood is recommended.",
    )
    parser.add_argument(
        "--add-distance-loss-evid-ll",
        action="store_true",
        default=False,
        help="Use distance loss for evidence LL scores.",
    )
    parser.add_argument(
        "--add-evid-loss",
        action="store_true",
        default=False,
        help="Use distance loss on evidence variables.",
    )
    parser.add_argument(
        "--evid-lambda",
        type=float,
        default=0.1,
        help="Multiplier for distance loss (default: 0.1).",
    )
    parser.add_argument(
        "--add-supervised-loss",
        action="store_true",
        default=False,
        help="Use supervised loss.",
    )
    parser.add_argument(
        "--supervised-loss-lambda",
        type=float,
        default=0.1,
        help="Multiplier for supervised loss (default: 0.1).",
    )
    parser.add_argument(
        "--add-entropy-loss",
        action="store_true",
        default=False,
        help="Add entropy loss to the model.",
    )
    parser.add_argument(
        "--entropy-lambda",
        type=float,
        default=0.01,
        help="Multiplier of the entropy loss (default: 0.01).",
    )

    # dual network training
    parser.add_argument(
        "--dual-network",
        action="store_true",
        default=False,
        help="Use dual network training.",
    )
    parser.add_argument(
        "--use-pgm-optimal-dn",
        action="store_true",
        default=False,
        help="Use PGM optimal solutions to initialize the database for dual network.",
    )
    parser.add_argument(
        "--loss-dn",
        default="bce",
        choices=[
            "mse",
            "bce",
            "kl",
        ],
        help="Loss function for dual network (default: bce).",
    )
    parser.add_argument(
        "--tot-train-dn",
        type=int,
        default=100,
        help="Total number ITSELF iterations during training of dual network (default: 100).",
    )
    parser.add_argument(
        "--student-train-iter-dn",
        type=int,
        default=1,
        help="Number of iterations for training the student model (default: 1).",
    )
    parser.add_argument(
        "--copy-student-to-teacher-dn",
        action="store_true",
        default=False,
        help="Copy student parameters to teacher at the start of each epoch.",
    )
    parser.add_argument(
        "--num-train-examples",
        type=int,
        default=0,
        help="Number of training examples to use (default: 0).",
    )
    parser.add_argument(
        "--early-stopping-patience",
        type=int,
        default=15,
        help="Early stopping patience (default: 15).",
    )
    parser.add_argument(
        "--train-on-test-set",
        action="store_true",
        default=False,
        help="Train on test set for self-supervised methods.",
    )
    parser.add_argument(
        "--only-test-train-on-test-set",
        action="store_true",
        default=False,
        help="Only test on train on test set. This can be used if you don't want to use a pre-trained model and just perform inference on a random model.",
    )
    parser.add_argument(
        "--train-on-test-set-scheduler",
        choices=["StepLR", "None"],
        default="None",
        help="LR Scheduler (default: None).",
    )
    parser.add_argument(
        "--duplicate-example-train-on-test",
        action="store_true",
        default=False,
        help="Duplicate examples for training on test set - create a batch with single example.",
    )
    parser.add_argument(
        "--perturb-model-train-on-test",
        action="store_true",
        default=False,
        help="Perturb the model.",
    )
    parser.add_argument(
        "--num-init-train-on-test",
        type=int,
        default=1,
        help="Number of initializations to try for ITSELF. More initializations can allow for the procedure to get better scores (default: 1).",
    )
    parser.add_argument(
        "--num-iter-train-on-test",
        type=int,
        default=100,
        help="Number of epochs to train on the test set (default: 100).",
    )
    parser.add_argument(
        "--use-batch-train-on-test",
        action="store_true",
        default=False,
        help="Use batch training on test set.",
    )
    parser.add_argument(
        "--num-test-examples",
        type=int,
        default=1000,
        help="Number of test examples to use (default: 1000).",
    )
    parser.add_argument(
        "--task",
        type=str,
        choices=["mpe", "mmap"],
        default="mpe",
        help="Inference task type (default: mpe).",
    )
    parser.add_argument(
        "--not-save-model",
        action="store_true",
        default=False,
        help="Disable saving the current model.",
    )
    # NOTE: ``store_true`` with ``default=True`` -- always True, the flag is a no-op.
    parser.add_argument(
        "--no-extra-data",
        action="store_true",
        default=True,
        help="Do not use extra data. Apart from the training set, additional sampled data can be used.",
    )
    parser.add_argument(
        "--vertical-line",
        action="store_true",
        default=False,
        help="Create a vertical line between evidence and query.",
    )
    parser.add_argument(
        "--debug-tuning",
        action="store_true",
        default=False,
        help="Find the best hyperparameters and model architecture.",
    )

    cfg = parser.parse_args()
    return process_arguments(cfg)


def process_arguments(cfg):
    """Post-process a parsed namespace and derive the project name.

    The steps run in a fixed order and mutate ``cfg`` in place: test-only
    setup, saved-bucket handling, model-path rewriting, embedding model path,
    probability computation/validation and a timestamp.

    Returns:
        A ``(cfg, project_name)`` tuple.
    """
    if cfg.no_train:
        # Test-only run: check the flags, derive the model path, validate it.
        for prepare in (
            ensure_training_setup,
            update_model_path,
            validate_model_and_buckets,
        ):
            prepare(cfg)

    if cfg.use_saved_buckets:
        use_saved_buckets(cfg)

    if should_update_model_path(cfg):
        update_nn_model_path(cfg)

    if cfg.add_entropy_loss:
        log_entropy_loss_setup(cfg)

    get_embedding_model_path(cfg)

    log_query_probability_setup(cfg)
    calculate_probabilities(cfg)
    validate_probabilities(cfg)

    add_time_to_cfg(cfg)

    return cfg, generate_project_name(cfg)


def add_time_to_cfg(cfg):
    """Store the current time on ``cfg.time`` formatted as ``YYYYMMDD-HHMMSS``."""
    timestamp_format = "%Y%m%d-%H%M%S"
    setattr(cfg, "time", time.strftime(timestamp_format))


def get_embedding_model_path(cfg, num_hidden_layers=4):
    """Derive the autoencoder checkpoint path and store it on ``cfg``.

    The path is built from ``cfg.dataset_directory`` by replacing
    ``"spn_sampled_dataset"`` (when present) or otherwise ``"datasets"`` with
    ``"autoencoder_models"``, then appending ``cfg.dataset`` and
    ``model_num_hidden_<num_hidden_layers>.pt``. The result is written to
    ``cfg.embedding_model_path``.

    Args:
        cfg: Namespace providing ``dataset_directory`` and ``dataset``.
        num_hidden_layers: Number used in the checkpoint file name.

    Raises:
        ValueError: If ``cfg.dataset_directory`` is ``None`` (the CLI option
            has no default), instead of failing with an opaque ``TypeError``.
    """
    dataset_directory = cfg.dataset_directory
    if dataset_directory is None:
        raise ValueError(
            "Please provide --dataset-directory; it is required to locate the "
            "embedding model"
        )

    # Only one of the two markers is rewritten; the sampled-dataset marker wins.
    if "spn_sampled_dataset" in dataset_directory:
        marker = "spn_sampled_dataset"
    else:
        marker = "datasets"
    model_root = dataset_directory.replace(marker, "autoencoder_models")

    cfg.embedding_model_path = os.path.join(
        model_root, cfg.dataset, f"model_num_hidden_{num_hidden_layers}.pt"
    )


def ensure_training_setup(cfg):
    """Assert the bucket-related flags agree with ``cfg.no_train``, then log.

    The check passes when ``cfg.no_train`` equals ``cfg.use_saved_buckets`` or
    equals ``cfg.only_test_train_on_test_set``.

    Raises:
        AssertionError: If neither flag matches ``cfg.no_train``.
    """
    skip_training = cfg.no_train
    assert (
        skip_training == cfg.use_saved_buckets
        or skip_training == cfg.only_test_train_on_test_set
    ), "We need buckets to test the model"
    logger.info("Not training the model")


def update_model_path(cfg):
    """Derive ``cfg.nn_model_path`` from ``cfg.saved_buckets_directory``.

    ``"model_outputs"`` is replaced by ``"models"`` in the bucket directory,
    ``model.pt`` is appended, and any ``"nn_merge/"`` segment is removed. The
    resulting path is logged.
    """
    models_directory = cfg.saved_buckets_directory.replace("model_outputs", "models")
    model_file = os.path.join(models_directory, "model.pt")
    if "nn_merge" in model_file:
        model_file = model_file.replace("nn_merge/", "")
    cfg.nn_model_path = model_file
    logger.info(f"Using trained model from the directory: {cfg.nn_model_path}")


def validate_model_and_buckets(cfg):
    """Assert that a model path and a bucket source are configured.

    Raises:
        AssertionError: If ``cfg.nn_model_path`` is ``None``, or if neither
            ``cfg.use_saved_buckets`` nor ``cfg.only_test_train_on_test_set``
            is truthy.
    """
    has_model_path = cfg.nn_model_path is not None
    assert has_model_path, "Please provide the model directory for a trained NN model"

    assert (
        cfg.use_saved_buckets or cfg.only_test_train_on_test_set
    ), "Please provide previously saved buckets path - we use the same buckets for testing"


def use_saved_buckets(cfg):
    """Log that saved buckets are in use, then sync ``cfg`` from their directory."""
    report = logger.info
    report("Using saved buckets")
    report(f"Saved buckets directory: {cfg.saved_buckets_directory}")
    report("Please note using this might delete old models and outputs")
    # Pull probabilities, task and dataset out of the directory name.
    update_task_and_dataset(cfg)


def should_update_model_path(cfg):
    """Tell whether the NN model path needs rewriting.

    The result is truthy only when ``cfg.saved_buckets_directory`` mentions
    ``"mnist"`` or ``"cifar"`` and ``cfg.no_train`` is set.
    """
    bucket_directory = cfg.saved_buckets_directory
    mentions_image_dataset = any(
        keyword in bucket_directory for keyword in ("mnist", "cifar")
    )
    return mentions_image_dataset and cfg.no_train


def update_nn_model_path(cfg):
    """Collapse ``/models_mpe_<suffix>/`` to ``/models_mpe/`` in ``cfg.nn_model_path``."""
    suffixed_models_dir = re.compile(r"/models_mpe_\w+/")
    cfg.nn_model_path = suffixed_models_dir.sub("/models_mpe/", cfg.nn_model_path)


def log_entropy_loss_setup(cfg):
    """Log that the entropy loss is enabled together with ``cfg.entropy_lambda``."""
    report = logger.info
    report("Using entropy loss")
    report(f"Entropy lambda: {cfg.entropy_lambda}")


def log_query_probability_setup(cfg):
    """Log a reminder about how the probabilities are set; ``cfg`` is unused."""
    reminder = (
        "Please set the value of evidence and we can calculate the query probability"
    )
    logger.info(reminder)


def calculate_probabilities(cfg):
    """Set ``cfg.evidence_prob`` and ``cfg.others_prob`` from ``cfg.query_prob``.

    For ``"mpe"`` everything left after the query share goes to evidence and
    ``others_prob`` is 0. For ``"mmap"`` the remainder is split evenly between
    evidence and the other variables.

    Raises:
        ValueError: If ``cfg.task`` is neither ``"mpe"`` nor ``"mmap"``.
    """
    task = cfg.task
    if task not in ("mpe", "mmap"):
        raise ValueError("Please select a task")

    remainder = 1 - cfg.query_prob
    if task == "mpe":
        cfg.evidence_prob = remainder
        cfg.others_prob = 0
        return

    # mmap: evidence and the remaining variables each get half of the rest.
    half_share = remainder / 2
    cfg.others_prob = half_share
    cfg.evidence_prob = half_share


def validate_probabilities(cfg):
    """Assert that query, evidence and others probabilities sum to 1.

    Raises:
        AssertionError: If the sum differs from 1 by 0.0001 or more.
    """
    total = cfg.query_prob + cfg.evidence_prob + cfg.others_prob
    deviation = abs(total - 1)
    assert (
        deviation < 0.0001
    ), f"{cfg.query_prob} + {cfg.evidence_prob} + {cfg.others_prob} != 1"


def generate_project_name(cfg):
    """Build the underscore-separated project name from the run settings.

    The name is prefixed with ``"only_test_"`` when ``cfg.no_train`` is set.
    """
    segments = [
        "RB",
        f"Task-{cfg.task}",
        f"PGM-{cfg.pgm}",
        f"SampledBuckets-{cfg.use_sampled_buckets}",
        f"Dataset-{cfg.dataset}",
        f"Model-{cfg.model}",
        f"Type-{cfg.model_type}",
        f"NumLayers-{cfg.student_layers}-{cfg.teacher_layers}",
        f"QueryProb-{cfg.query_prob}",
        f"EvidProb-{cfg.evidence_prob}",
    ]
    name = "_".join(segments)
    return "only_test_" + name if cfg.no_train else name


def _probability_from_directory(directory, prefix):
    """Return the number that follows ``prefix`` in ``directory``, or 0.0 if absent."""
    remainder = None
    # Tokens are underscore-separated; the last token carrying the prefix wins.
    for part in directory.split("_"):
        if part.startswith(prefix):
            remainder = part[len(prefix):]
    if not remainder:
        return 0.0

    # Read only the leading numeric literal so that trailing path text
    # ("0.3/model") is ignored and exponents ("1e-05") are kept intact.
    match = re.match(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?", remainder)
    if match is None:
        raise ValueError(
            f"Could not parse a probability after {prefix!r} in {directory!r}"
        )
    return float(match.group())


def update_task_and_dataset(cfg):
    """Sync probabilities, task and dataset from ``cfg.saved_buckets_directory``.

    The directory is scanned for ``EvidProb-<value>`` and ``QueryProb-<value>``
    tokens; the parsed values are written to ``cfg.evidence_prob`` and
    ``cfg.query_prob``. A token that is missing or has no value yields 0.0.
    The task and dataset are then updated from the same directory.

    Raises:
        ValueError: If a token is present but its value is not numeric.
    """
    directory = cfg.saved_buckets_directory
    cfg.evidence_prob = _probability_from_directory(directory, "EvidProb-")
    cfg.query_prob = _probability_from_directory(directory, "QueryProb-")
    update_task_from_directory(cfg)
    update_dataset_from_directory(cfg)


def update_task_from_directory(cfg):
    """Set ``cfg.task`` from the saved-buckets directory name, if it names one.

    The match is case-insensitive and ``"mmap"`` is tried before ``"mpe"``.
    ``cfg.task`` is left untouched when neither appears.
    """
    lowered_directory = cfg.saved_buckets_directory.lower()
    for candidate in ("mmap", "mpe"):
        if candidate in lowered_directory:
            cfg.task = candidate
            break


def update_dataset_from_directory(cfg):
    """Set ``cfg.dataset`` to a known dataset name found in the bucket directory.

    Names from ``dataset_names`` are tried in order against the lower-cased
    directory. The first match is kept and logged, except that a name
    containing ``"mnist"`` does not end the search when the directory contains
    ``"emnist"``; later names may then overwrite ``cfg.dataset``.
    """
    lowered_directory = cfg.saved_buckets_directory.lower()
    directory_mentions_emnist = "emnist" in lowered_directory

    for candidate in dataset_names:
        if candidate not in lowered_directory:
            continue
        cfg.dataset = candidate
        # NOTE(review): "emnist" itself contains "mnist", so inside an emnist
        # directory no mnist-like name ever logs or breaks and the last
        # matching name wins -- confirm this is intended.
        if not ("mnist" in candidate and directory_mentions_emnist):
            logger.info(f"Using dataset: {cfg.dataset}")
            break
