from __future__ import print_function

import os
import sys

sys.path.append(
    # Add the path of the anympe directory here
)  # Adds the parent directory to the system path
import lightning as L
import numpy as np
import torch
from get_model import init_model_and_optimizer
from loguru import logger
from nn_scripts import train, train_dual_network, validate
from project_utils.data_utils import (
    get_dataloaders,
    get_mpe_solutions,
    get_num_var_in_buckets,
    init_train_test_data,
    load_data,
)
from project_utils.experiment_utils import check_previous_runs, test_assertions
from project_utils.logging_utils import init_logger_and_wandb, load_args_from_yaml
from project_utils.model_utils import (
    get_num_features,
    get_pgm_loss,
    init_embedding_sizes,
)
from runner_helper import *
from utils_folder.arguments import define_arguments

# Every tensor created without an explicit dtype is float64; this matches the
# Fabric ``precision="64"`` setting used in ``runner``.
torch.set_default_dtype(torch.float64)


# Make the anympe directory importable. Set the ANYMPE_DIR environment variable
# to the anympe checkout; by default the parent of this file's directory is used.
# NOTE: this runs after the imports above, so it only affects imports performed
# later (e.g. lazily inside helper modules).
_ANYMPE_DIR = os.environ.get(
    "ANYMPE_DIR", os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
if _ANYMPE_DIR not in sys.path:
    sys.path.append(_ANYMPE_DIR)


def _unique_test_buckets(test_buckets, device):
    """Return the distinct rows of every test bucket mask as boolean tensors.

    Each entry of ``test_buckets`` is de-duplicated row-wise with
    ``np.unique(..., axis=0)``, converted to a bool tensor and moved to
    ``device``. When there is a single distinct "unobs" row it is tiled so that
    "unobs" has as many rows as "evid".
    """
    unique_buckets = {
        key: torch.from_numpy(np.unique(test_buckets[key], axis=0))
        .bool()
        .to(device=device)
        for key in test_buckets
    }
    if unique_buckets["unobs"].shape[0] == 1:
        unique_buckets["unobs"] = unique_buckets["unobs"].repeat(
            len(unique_buckets["evid"]), 1
        )
    logger.info(f"Number of unique buckets: {len(unique_buckets['evid'])}")
    return unique_buckets


def _snapshot_training_state(model, optimizer, epoch):
    """Return a checkpoint dict holding *copies* of the model/optimizer state.

    ``state_dict()`` returns tensors that share storage with the live
    parameters and optimizer buffers. Storing it directly lets later optimizer
    steps silently overwrite the "best" checkpoint, so the state is deep-copied
    to freeze it as of ``epoch``.
    """
    from copy import deepcopy  # local import keeps the file's import block untouched

    return {
        "epoch": epoch,
        "model_state": deepcopy(model.state_dict()),
        "optimizer_state": deepcopy(optimizer.state_dict()),
    }


def _step_epoch_scheduler(lr_scheduler, val_loss):
    """Advance an epoch-level learning-rate scheduler once.

    Does nothing when ``lr_scheduler`` is None or is a ``CyclicLR`` /
    ``OneCycleLR`` instance (presumably stepped per batch by the training
    functions, which receive the scheduler; TODO confirm).
    ``ReduceLROnPlateau`` is stepped with ``val_loss``; any other scheduler is
    stepped without arguments.
    """
    if lr_scheduler is None:
        return
    per_batch_schedulers = (
        torch.optim.lr_scheduler.CyclicLR,
        torch.optim.lr_scheduler.OneCycleLR,
    )
    if isinstance(lr_scheduler, per_batch_schedulers):
        return
    if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        lr_scheduler.step(val_loss)
    else:
        lr_scheduler.step()


@logger.catch
def runner(cfg, project_name):
    """Load data, build the model(s), optionally train, then test.

    Pipeline:
      1. Load train/valid/test data and bucket masks for ``cfg.dataset``.
      2. Reuse results of a previous run via ``check_previous_runs``.
      3. Set up logging/wandb and a single-device Lightning ``Fabric`` in
         64-bit precision.
      4. Build the data loaders, the PGM loss and the model/optimizer.
         A student/teacher pair is built when ``cfg.dual_network`` is set.
      5. Unless ``cfg.no_train``:
         - train for up to ``cfg.epochs`` epochs, validating after each one;
         - keep a copy of the state with the lowest validation loss;
         - stop early once ``validate`` reports ``patience`` epochs without
           improvement;
         - save that copy to ``{cfg.model_dir}/model.pt``.
      6. Unless ``cfg.not_save_model`` or ``cfg.no_test``, run
         ``test_and_process_outputs`` with the best checkpoint.

    Args:
        cfg: Experiment configuration namespace. It is mutated in place:
            ``no_extra_data`` is forced to ``True``, and ``device`` and (after
            training) ``model_path`` are set.
        project_name: Project name forwarded to ``init_directories`` and
            ``init_logger_and_wandb``.
    """
    # Extra sampled data is deliberately disabled for this runner, so the
    # concatenation branch further down is inactive.
    cfg.no_extra_data = True
    init_debug(cfg)
    data, extra_data, buckets = load_data(cfg, cfg.dataset)
    train_data, test_data, val_data = data["train"], data["test"], data["valid"]
    val_buckets, test_buckets = buckets["valid"], buckets["test"]
    test_data, test_buckets, train_data, val_data, val_buckets = init_train_test_data(
        cfg, test_data, test_buckets, val_data, val_buckets, train_data
    )
    num_nodes_in_graph = train_data.shape[1]
    init_embedding_sizes(cfg, num_nodes_in_graph)
    init_directories(cfg, project_name)

    # Check if this experiment has already been run.
    best_model_info, train_data, test_data, val_data = check_previous_runs(
        cfg, train_data, test_data, val_data
    )
    device, use_cuda, use_mps = init_logger_and_wandb(project_name, cfg)
    fabric = L.Fabric(accelerator=device, devices=1, precision="64")
    fabric.launch()
    cfg.device = device
    torch.manual_seed(cfg.seed)

    unique_buckets = (
        _unique_test_buckets(test_buckets, device) if cfg.use_sampled_buckets else None
    )
    if not cfg.no_extra_data:
        train_data = torch.cat((train_data, extra_data), dim=0)
        logger.info("We are adding extra sampled data")
    logger.info(f"Train data shape: {train_data.shape}")

    num_outputs = train_data.shape[1]
    torch_pgm, library_pgm = get_pgm_loss(cfg, device, num_outputs)
    num_data_features, num_pgm_feature = get_num_features(cfg, train_data, torch_pgm)
    num_variables_in_buckets = get_num_var_in_buckets(cfg, num_outputs)

    train_loader, test_loader, val_loader = get_dataloaders(
        cfg,
        train_data,
        test_data,
        val_data,
        val_buckets,
        test_buckets,
        use_cuda,
        unique_buckets,
        torch_pgm,
        num_variables_in_buckets,
    )
    num_query_variables = num_variables_in_buckets[1]
    num_layers = cfg.student_layers
    # The hidden-layer sizes are only logged here.
    predefined_layers = [128 * (2**i) for i in range(num_layers)]
    logger.info(f"Number of layers: {num_layers}")
    logger.info(f"Hidden layers: {predefined_layers}")
    models_and_optimizers = init_model_and_optimizer(
        cfg,
        library_pgm,
        device,
        fabric,
        num_data_features,
        num_pgm_feature,
        num_outputs,
        num_query_variables=num_query_variables,
        run_type="train",
    )

    # Saved MPE solutions from the baseline; zero placeholders keep the
    # expected keys available when the baseline did not finish.
    mpe_solutions = get_mpe_solutions(cfg)
    if mpe_solutions is None and cfg.pgm != "spn":
        logger.info("No MPE solutions found. The baseline did not finish running")
        mpe_solutions = {
            "test_mpe_output": np.zeros_like(test_data),
            "test_root_ll_pgm": np.zeros(test_data.shape[0]),
        }

    if cfg.dual_network:
        # The student model is the one used for testing.
        model, optimizer, teacher_model, teacher_optimizer = models_and_optimizers
        logger.info(f"Student Model: {model}")
        logger.info(f"Teacher Model: {teacher_model}")
        database_for_student = torch.zeros_like(
            train_loader.dataset.data, device=device
        )
        bucket_database_for_student = {
            key: torch.zeros_like(train_loader.dataset.data, device=device, dtype=bool)
            for key in test_buckets
        }
        # Start every per-example best loss at a very high value.
        best_loss_database = (
            torch.ones(train_loader.dataset.data.shape[0], device=device) * 1e10
        )
    else:
        model, optimizer = models_and_optimizers
        logger.info(f"Model: {model}")

    lr_scheduler = select_lr_scheduler(
        cfg,
        cfg.lr_scheduler,
        optimizer,
        train_loader,
        verbose=True,
    )
    train_loader = fabric.setup_dataloaders(train_loader)
    test_loader = fabric.setup_dataloaders(test_loader)
    val_loader = fabric.setup_dataloaders(val_loader)
    best_loss = float("inf")
    counter = 0
    patience = 5  # epochs without validation improvement before stopping early

    all_train_losses = []
    all_val_losses = []

    if not cfg.no_train:
        best_val_loss = float("inf")
        logger.info("First validation")
        best_loss, val_loss, counter, _, _, _, _ = validate(
            cfg, model, torch_pgm, device, val_loader, best_loss, counter
        )
        all_val_losses.append(val_loss)
        logger.info("Training the model")
        for epoch in range(1, cfg.epochs + 1):
            if cfg.dual_network:
                train_loss = train_dual_network(
                    cfg,
                    model,
                    teacher_model,
                    torch_pgm,
                    library_pgm,
                    device,
                    fabric,
                    train_loader,
                    optimizer,
                    teacher_optimizer,
                    database_for_student,
                    bucket_database_for_student,
                    best_loss_database,
                    epoch,
                    schedular=lr_scheduler,
                )
            else:
                train_loss = train(
                    cfg,
                    model,
                    torch_pgm,
                    device,
                    fabric,
                    train_loader,
                    optimizer,
                    epoch,
                    schedular=lr_scheduler,
                )
            all_train_losses.append(train_loss)

            # Validate after every epoch.
            best_loss, val_loss, counter, _, _, _, _ = validate(
                cfg, model, torch_pgm, device, val_loader, best_loss, counter
            )
            all_val_losses.append(val_loss)

            # Epoch-level schedulers are only stepped for runs longer than 2 epochs.
            if cfg.epochs > 2:
                _step_epoch_scheduler(lr_scheduler, val_loss)
            logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Copy the state: a bare state_dict() aliases the live weights
                # and would be overwritten by the following epochs.
                best_model_info = _snapshot_training_state(model, optimizer, epoch)
            if counter >= patience:
                logger.info(
                    "Validation loss hasn't improved for {} epochs, stopping training...".format(
                        patience
                    )
                )
                break
        logger.info("Training completed!")
        logger.info("Saving model...")
        best_model_path = f"{cfg.model_dir}/model.pt"
        torch.save(best_model_info, best_model_path)
        logger.info(f"Best model saved at {best_model_path}")
        cfg.model_path = best_model_path

    if not cfg.not_save_model and not cfg.no_test:
        test_assertions(cfg, best_model_info)
        if best_model_info is None or best_model_info["epoch"] is None:
            # No trained checkpoint is available: test the current (untrained) model.
            best_model_info = _snapshot_training_state(model, optimizer, 0)
        test_and_process_outputs(
            cfg,
            device,
            fabric,
            torch_pgm,
            cfg.model_dir,
            cfg.model_outputs_dir,
            train_loader,
            test_loader,
            val_loader,
            mpe_solutions,
            model,
            library_pgm,
            optimizer,
            best_loss,
            counter,
            all_train_losses,
            all_val_losses,
            best_model_info,
            num_data_features,
            num_pgm_feature,
            num_outputs,
            num_query_variables,
        )


if __name__ == "__main__":
    # Parse the experiment configuration, then run the whole experiment once.
    experiment_cfg, experiment_project = define_arguments()
    runner(experiment_cfg, experiment_project)
