from __future__ import print_function

import argparse
import os
import random
import sys
import time
from pprint import pprint

import numpy as np
import torch
import wandb
from deeprob.spn.algorithms.inference import log_likelihood, mpe
from deeprob.spn.structure.io import load_spn_json
from inference_script import hill_climbing_search
from loguru import logger
from matplotlib import pyplot as plt
from utils import init_logger_and_wandb

# NOTE: unfilled placeholder. The call below passes no argument, and
# list.append() with no argument raises TypeError, so this module fails at
# import time until the path to the anympe directory is written in.
# Presumably that directory is where get_spn_class_log lives -- TODO confirm.
sys.path.append(
    # Add the path of the anympe directory here
)  # Adds the parent directory to the system path
from get_spn_class_log import SPNModel


def _parse_key_value_tokens(segment, target):
    """Store every ``Key-Value`` token of one ``_``-separated segment in *target*."""
    for token in segment.split("_"):
        parts = token.split("-")
        # Only "Key-Value" and "Key-Value-Extra" are understood; for the
        # three-part form the trailing part is dropped.  Tokens without a "-"
        # or with more than three parts are skipped instead of re-using the
        # key/value of a previous token (or failing on unbound names).
        if len(parts) in (2, 3):
            target[parts[0]] = parts[1]


def get_hyper_param(path):
    """Extract hyper-parameters encoded in the directory names of *path*.

    The path is split on ``/``.  The third-from-last and second-from-last
    components are each split on ``_`` and every ``Key-Value`` (or
    ``Key-Value-Extra``) token contributes one ``Key -> Value`` entry.
    Entries from the second-from-last component override earlier ones with
    the same key.

    Args:
        path: A ``/``-separated path with at least three components.

    Returns:
        dict mapping each parsed key to its value; values are kept as strings.

    Raises:
        IndexError: If *path* has fewer than three ``/``-separated components.
    """
    splits = path.split("/")
    hyper_params_dict = {}
    _parse_key_value_tokens(splits[-3], hyper_params_dict)
    _parse_key_value_tokens(splits[-2], hyper_params_dict)
    return hyper_params_dict


@logger.catch()
def evaluate_random_sampling(args, project_name, run_name):
    """Run hill-climbing search on stored test outputs and score the results.

    Loads ``test_outputs.npz`` from ``args.nn_model_output_directory``, picks
    the starting assignments according to ``args.initialization``, calls
    ``hill_climbing_search`` and then computes log-likelihoods for the
    starting points, the hill-climbing result and the output of ``mpe``.
    Arrays and summary statistics are written to ``random_sample_outputs.npz``
    in the output directory; the mean scores are logged and sent to wandb.

    Args:
        args: Parsed command-line namespace.  Reads
            ``nn_model_output_directory``, ``spn_model_directory``,
            ``dataset``, ``initialization``, ``debug`` and ``replace``.
            ``output_dir``, ``pgm_model_path`` and ``device`` are set on it
            as a side effect.
        project_name: Forwarded to ``init_logger_and_wandb``; also used as
            the last output directory component in debug mode.
        run_name: Forwarded to ``init_logger_and_wandb``.

    Returns:
        None.  Returns early (after logger/wandb initialisation) when the
        output file already exists and ``args.replace`` is false.

    Raises:
        FileNotFoundError: If ``test_outputs.npz`` is missing.
        ValueError: If ``args.initialization`` is not a supported value.

    Note:
        The function is wrapped in ``logger.catch()``; with loguru's default
        settings exceptions are presumably logged and not re-raised, so
        callers may simply get ``None`` -- confirm against loguru's docs.
    """
    nn_output_location = os.path.join(
        args.nn_model_output_directory,
        "test_outputs.npz",
    )
    if not os.path.isfile(nn_output_location):
        raise FileNotFoundError(
            f"Expected NN outputs at {nn_output_location}, but the file does not exist"
        )
    hyper_params_dict = get_hyper_param(args.nn_model_output_directory)

    if args.debug:
        output_dir = os.path.join(
            "debug", "models", "hill_climbing_search", project_name
        )
    else:
        # Four levels above test_outputs.npz is the base experiments directory.
        base_experiments_location = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname((nn_output_location))))
        )
        output_dir = os.path.join(
            base_experiments_location,
            "hill_climbing_search_2",
            f"Task-{hyper_params_dict['Task']}",
            f"{args.dataset}_QueryProb-{hyper_params_dict['QueryProb']}_EvidProb-{hyper_params_dict['EvidProb']}",
            args.initialization,
        )
    logger.info(f"Output directory is {output_dir}")
    args.output_dir = output_dir
    model_outputs_dir = output_dir
    output_file_path = os.path.join(
        model_outputs_dir,
        "random_sample_outputs.npz",
    )

    dataset_name = args.dataset

    args.pgm_model_path = os.path.join(
        args.spn_model_directory, f"{dataset_name}/spn.json"
    )
    device, use_cuda, use_mps = init_logger_and_wandb(project_name, run_name, args)
    args.device = device
    print(output_file_path)
    if os.path.exists(output_file_path):
        logger.info(f"Output file {output_file_path} already exists")
        if not args.replace:
            logger.info("Skipping the current run")
            return
        logger.info("Replacing the existing file")
    logger.info(f"Using {args.initialization} for initializing the samples")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    prev_outputs = np.load(
        nn_output_location,
        allow_pickle=True,
    )
    # "all_buckets" is stored as a 0-d object array; [()] unwraps the dict.
    buckets = prev_outputs["all_buckets"][()]
    evid_var_bool = np.array(buckets["evid"])
    query_var_bool = np.array(buckets["query"])
    unobs_var_bool = np.array(buckets["unobs"])
    # NOTE(review): this equals the fraction of query variables only if the
    # mask is 1-D; for a 2-D (sample x variable) mask it is the mean number
    # of query variables per row -- confirm the mask shape.
    percentage_query = np.sum(query_var_bool) / len(query_var_bool)
    buckets = {"evid": evid_var_bool, "query": query_var_bool, "unobs": unobs_var_bool}

    data = prev_outputs["all_unprocessed_data"]
    if args.initialization == "nn_approx":
        # Start the search from the network's predictions on the query entries.
        nn_outputs = prev_outputs["all_nn_outputs"]
        data[query_var_bool] = nn_outputs[query_var_bool]
    elif args.initialization not in ("random", "spn_approx", "mle", "sequential"):
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"Unsupported initialization: {args.initialization!r}")
    n_features = data.shape[1]
    logger.info(f"You have selected {args.dataset}")

    # Load the SPN twice: as a deeprob structure and as the project's torch model.
    library_spn = load_spn_json(args.pgm_model_path)
    torch_spn = SPNModel(
        args.pgm_model_path,
        num_var=n_features,
        device=device,
        percent_nodes_for_features=0.8,
    )
    logger.info(f"Loaded SPN model from {args.pgm_model_path}")

    # Run the search; the fourth returned value is not used here.
    (
        initial_data_points,
        initial_ll_scores,
        hc_samples,
        _,
        library_spn_marginalized,
    ) = hill_climbing_search(
        args,
        data,
        prev_outputs,
        query_var_bool,
        unobs_var_bool,
        evid_var_bool,
        library_spn,
        torch_spn,
        buckets,
    )
    if os.path.exists(model_outputs_dir):
        print("Folder already exists!")
    else:
        # Recreate the folder in case it disappeared during the search.
        os.makedirs(model_outputs_dir, exist_ok=True)
        print("Folder created successfully!")

    ll_score_hc_sample = log_likelihood(
        library_spn_marginalized,
        hc_samples,
        return_results=False,
    )
    # The scores returned by the search are replaced by ones recomputed on the
    # marginalized SPN, so all three methods are scored by the same model.
    initial_ll_scores = log_likelihood(
        library_spn_marginalized,
        initial_data_points,
        return_results=False,
    )

    # Baseline input for mpe: evidence copied from the data, query entries set
    # to NaN (presumably mpe fills NaN entries -- confirm), everything else 0.
    array_for_spn = np.zeros_like(data)
    array_for_spn[query_var_bool] = np.nan
    array_for_spn[evid_var_bool] = data[evid_var_bool]
    logger.info(f"Array for spn {array_for_spn.shape}")
    mpe_output = mpe(library_spn_marginalized, array_for_spn)
    root_ll_spn = log_likelihood(
        library_spn_marginalized,
        mpe_output,
        return_results=False,
    )

    mean_ll_hc_sample = np.mean(ll_score_hc_sample)
    std_ll_hc_sample = np.std(ll_score_hc_sample)
    mean_ll_spn = np.mean(root_ll_spn)
    std_ll_spn = np.std(root_ll_spn)
    mean_ll_initial_sample = np.mean(initial_ll_scores)
    std_ll_initial_sample = np.std(initial_ll_scores)

    np.savez(
        output_file_path,
        base_method_output=initial_data_points,
        root_ll_base=initial_ll_scores,
        mean_ll_initial_sample=mean_ll_initial_sample,
        std_ll_initial_sample=std_ll_initial_sample,
        output_samples=hc_samples,
        root_ll_sample=ll_score_hc_sample,
        root_ll_spn=root_ll_spn,
        mean_ll_sample=mean_ll_hc_sample,
        mean_ll_spn=mean_ll_spn,
        std_ll_sample=std_ll_hc_sample,
        std_ll_spn=std_ll_spn,
        all_buckets=buckets,
        all_unprocessed_data=data,
        spn_mpe_output=mpe_output,
    )
    logger.info(f"Dataset {dataset_name}, Query {percentage_query}")
    logger.info(
        f"LL score for Initial data from {args.initialization} is {mean_ll_initial_sample}"
    )
    logger.info(f"LL score for Sampled data is {mean_ll_hc_sample}")
    logger.info(f"LL score for SPN is {mean_ll_spn}")
    wandb.log(
        {
            f"LL score {args.initialization} sample": mean_ll_hc_sample,
            "LL Score SPN": mean_ll_spn,
        }
    )
    wandb.alert(
        title="Hill Climbing Search Inference Completed Method 1",
        text=f"Dataset: {dataset_name}, Query: {percentage_query}, LL score {args.initialization} \n Initial: {mean_ll_initial_sample}, HC: {mean_ll_hc_sample}, SPN: {mean_ll_spn}",
    )


if __name__ == "__main__":
    # Command-line entry point: parse the options, derive the dataset name
    # from the NN output directory, seed the RNGs and run the evaluation.
    parser = argparse.ArgumentParser(
        description="any-MPE Experiments - Hill Climbing Search Inference"
    )
    parser.add_argument(
        "--debug", action="store_true", default=False, help="Are we in debug mode?"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=5,
        metavar="N",
        help="number of steps to run (default: 5)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1,
        metavar="N",
        help="how many number of samples to test",
    )
    # Both directories are dereferenced unconditionally below, so a missing
    # value is reported by argparse instead of crashing later on None.
    parser.add_argument(
        "--spn-model-directory",
        type=str,
        required=True,
        metavar="SPN",
        help="Location of the SPN model",
    )
    parser.add_argument(
        "--nn-model-output-directory",
        type=str,
        required=True,
        metavar="model_output",
        help="Location of the outputs of the trained NN",
    )
    # How the starting samples are chosen.
    parser.add_argument(
        "--initialization",
        type=str,
        default="sequential",
        choices=[
            "spn_approx",
            # "random",
            "nn_approx",
            "mle",
            "sequential",
        ],
        help="Method used to initialize the samples",
    )
    parser.add_argument(
        "--no-local-search",
        action="store_true",
        default=False,
        help="Disable local search.",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="Disable CUDA training."
    )
    parser.add_argument(
        "--no-mps",
        action="store_true",
        default=False,
        help="Disable macOS GPU training.",
    )
    parser.add_argument(
        "--replace",
        action="store_true",
        default=False,
        help="Replace the existing solutions.",
    )
    parser.add_argument(
        "--not-use-num-layers",
        action="store_true",
        default=False,
        help="Don't use directories if 'NumLayers' is present",
    )

    args = parser.parse_args()
    exp_path = args.nn_model_output_directory

    # The dataset name is the value of the first "Dataset-<name>" token in the
    # second-from-last path component.
    dataset = None
    path_parts = exp_path.split("/")
    if len(path_parts) >= 2:
        for val in path_parts[-2].split("_"):
            if val.startswith("Dataset") and "-" in val:
                dataset = val.split("-")[1]
                break

    if args.not_use_num_layers and "NumLayers" in exp_path:
        # sys.exit is always available; the bare exit() helper comes from the
        # site module and is not guaranteed to exist.
        sys.exit()

    if dataset is None:
        parser.error(
            "could not find a 'Dataset-<name>' token in the parent directory "
            f"name of --nn-model-output-directory ({exp_path!r})"
        )
    args.dataset = dataset

    # Seed every RNG this file imports, not only the stdlib one, so --seed
    # also covers numpy- and torch-based randomness.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    pprint(args)
    project_name = "anympe_Random_HC"
    evaluate_random_sampling(
        args, project_name, f"HC_{args.initialization}_{args.dataset}"
    )
