import pandas as pd
import numpy as np
import json
import argparse
import torch
from gp import GP
from func.lake_zurich import LakeZurich
from func.phosphorus import Phosphorus
from func.intel_lab_data_humidity import IntelLabHumidity
from func.intel_lab_data_temperature import IntelLabTemperature


from util import normalize_data

import matplotlib.pyplot as plt


def _detached_to_list(tensor):
    """Return ``tensor``'s values as (nested) Python floats, detached from autograd."""
    return tensor.detach().cpu().numpy().tolist()


def fitgp(
    dataset_name,
    # function_module_name,
    # task_identifier,
    # xsize=100,
    # zsize=100,
    # noise_std=0.01,
    ntrain=100,
    # max_ndata=300,
    # ard=True,
    *,
    grid_size=100,
    show_plot=True,
):
    """Fit GP hyperparameters for a dataset, save them as JSON and plot the posterior mean.

    Steps:
    1. Load the training inputs and targets through the dataset class's
       ``preprocess_data``.
    2. Build a ``GP`` with the default hyperparameter prior and ``ard=True``.
    3. Optimize its hyperparameters for ``ntrain`` iterations with learning
       rate 0.1.
    4. Print the fitted values and write them to
       ``func/gp_hyperparameters/gp_hyperparameters_<dataset_name>.json``.
       The directory is created if it is missing.
    5. Evaluate the posterior mean on a ``grid_size`` x ``grid_size`` grid over
       [0, 1]^2 and draw it as an image.

    Args:
        dataset_name: "lake_zurich", "phosphorus", or a name starting with
            "intel_lab_data-temperature" / "intel_lab_data-humidity".
        ntrain: number of hyperparameter-optimization iterations.
        grid_size: number of grid points per axis for the posterior-mean plot.
        show_plot: if True, display the figure with ``plt.show()``; otherwise
            the figure is closed without being shown (useful for batch runs).

    Returns:
        dict mapping hyperparameter names to their fitted values as Python
        floats / nested lists. This is the same content written to the JSON file.

    Raises:
        ValueError: if ``dataset_name`` is not recognized.
    """
    from pathlib import Path

    if dataset_name == "lake_zurich":
        train_x, ys = LakeZurich.preprocess_data()
    elif dataset_name.startswith("intel_lab_data-temperature"):
        train_x, ys = IntelLabTemperature.preprocess_data()
    elif dataset_name.startswith("intel_lab_data-humidity"):
        train_x, ys = IntelLabHumidity.preprocess_data()
    elif dataset_name == "phosphorus":
        train_x, ys = Phosphorus.preprocess_data()
    else:
        raise ValueError(f"Unknown {dataset_name}.")

    print(f"Train_x {train_x.shape} {type(train_x)}")
    print(f"Ys {ys.shape} {type(ys)}")

    prior = GP.get_default_hyperparameter_prior()
    gp_model = GP(train_x, ys, prior=prior, ard=True)

    GP.optimize_hyperparameters(
        gp_model,
        train_x,
        ys,
        learning_rate=0.1,
        training_iter=ntrain,
        verbose=False,
    )

    hyperparameters = {
        "likelihood.noise_covar.noise": _detached_to_list(
            gp_model.likelihood.noise_covar.noise
        ),
        "mean_module.constant": _detached_to_list(gp_model.mean_module.constant),
        "covar_module.outputscale": _detached_to_list(
            gp_model.covar_module.outputscale
        ),
        "covar_module.base_kernel.lengthscale": _detached_to_list(
            gp_model.covar_module.base_kernel.lengthscale
        ),
    }

    print(json.dumps(hyperparameters, indent=2))
    print("")

    hyperparameter_path = (
        Path("func") / "gp_hyperparameters" / f"gp_hyperparameters_{dataset_name}.json"
    )
    # Create the output directory on first use instead of failing with FileNotFoundError.
    hyperparameter_path.parent.mkdir(parents=True, exist_ok=True)
    with open(hyperparameter_path, "w", encoding="utf-8") as f:
        json.dump(hyperparameters, f, indent=4)

    # Regular evaluation grid over the unit square.
    # NOTE(review): assumes the model takes 2-D inputs scaled to [0, 1] — confirm per dataset.
    axis = np.linspace(0.0, 1.0, grid_size)
    x0s, x1s = np.meshgrid(axis, axis)
    xs_test = torch.from_numpy(np.stack([x0s.ravel(), x1s.ravel()], axis=1)).float()

    # Predict without building an autograd graph. Detach before .numpy(),
    # because .numpy() raises on tensors that still require grad.
    with torch.no_grad():
        f_preds = GP.predict_f(gp_model, xs_test)
        f_means = f_preds.mean.detach().cpu().numpy().reshape(grid_size, grid_size)

    fig, ax = plt.subplots()
    # meshgrid's default "xy" indexing puts x1 along rows and x0 along columns.
    # origin="lower" draws x1 = 0 at the bottom, and extent labels the axes in
    # input units rather than pixel indices.
    im = ax.imshow(f_means, origin="lower", extent=(0.0, 1.0, 0.0, 1.0))
    ax.set_xlabel("x0")
    ax.set_ylabel("x1")
    fig.colorbar(im, ax=ax)
    if show_plot:
        plt.show()
    else:
        plt.close(fig)

    return hyperparameters


# """
# Example output from an earlier hyperparameter fit, kept for reference
# (the dataset it was fitted on is not recorded here):
# Noise variance tensor([0.0001], grad_fn=<AddBackward0>)
# Mean constant Parameter containing: tensor(-1.0601, requires_grad=True)
# Output scale tensor(9.7612, grad_fn=<SoftplusBackward0>)
# Lengthscale tensor([[0.8860]], grad_fn=<SoftplusBackward0>)
# """


def fit_prior(data1d_np, distribution="Gamma", lr=0.1, ntrain=1000):
    """Fit a univariate Gamma or Normal distribution to data by maximum likelihood.

    Adam minimizes the summed negative log-likelihood of the data.

    Some parameters must stay strictly positive: the Gamma concentration and
    rate, and the Normal std. These are optimized as their logarithms, so an
    optimizer step can never push them to zero or below. Otherwise the loss
    would become NaN, or the distribution constructor would reject them.
    The starting values are the same as before: Gamma(1, 1) and Normal(0, 1).

    Args:
        data1d_np: 1-D array-like of observations, shape (n,). A numpy array,
            a list or a tensor are all accepted.
        distribution: "Gamma", or "Normal" / "Gaussian".
        lr: Adam learning rate.
        ntrain: number of optimization steps.

    Returns:
        dict of scalar float32 leaf tensors (``requires_grad=True``). Keys are
        "concentration"/"rate" for Gamma, or "mean"/"std" for Normal.

    Raises:
        ValueError: if ``distribution`` is not one of the supported names.
    """
    data = torch.as_tensor(data1d_np)

    # ``initial`` is ordered like the positional arguments of ``dist_cls``.
    if distribution == "Gamma":
        dist_cls = torch.distributions.gamma.Gamma
        initial = {"concentration": 1.0, "rate": 1.0}
        positive = {"concentration", "rate"}
    elif distribution in ("Normal", "Gaussian"):
        dist_cls = torch.distributions.normal.Normal
        initial = {"mean": 0.0, "std": 1.0}
        positive = {"std"}
    else:
        raise ValueError("Unknown distribution!")

    # Unconstrained leaf tensors that Adam updates. Positive parameters are
    # stored as their log so that exp() keeps them strictly > 0.
    raw = {
        name: torch.tensor(
            float(np.log(value)) if name in positive else value, requires_grad=True
        )
        for name, value in initial.items()
    }

    def natural_params():
        # Map the unconstrained tensors back to distribution parameters.
        return [raw[name].exp() if name in positive else raw[name] for name in initial]

    optimizer = torch.optim.Adam(raw.values(), lr=lr)
    for _ in range(ntrain):
        optimizer.zero_grad()
        # Rebuild the distribution from the current parameters at every step.
        fitted_dist = dist_cls(*natural_params())
        neg_loglikelihood_loss = -torch.sum(fitted_dist.log_prob(data))
        neg_loglikelihood_loss.backward()
        optimizer.step()

    with torch.no_grad():
        fitted = dict(zip(initial, natural_params()))
    # Return leaf tensors with requires_grad=True, matching the kind of tensors
    # this function has always returned.
    return {
        name: value.detach().clone().requires_grad_() for name, value in fitted.items()
    }


def test_fit_prior():
    """Smoke-test ``fit_prior`` on samples drawn from a known distribution.

    Draws 100 samples from a reference distribution under a fixed seed, so
    the run is reproducible. The sampling happens inside
    ``torch.random.fork_rng``, which leaves the global RNG state untouched.
    It then fits the same family with ``fit_prior``, prints the fitted
    parameters, and checks that they are finite (and positive, for Gamma).
    """
    test_distribution = "Gamma"

    if test_distribution == "Normal":
        reference = torch.distributions.normal.Normal(2.0, 1.0)
    elif test_distribution == "Gamma":
        reference = torch.distributions.Gamma(1.0, 3.0)
    else:
        raise ValueError(f"Unsupported test distribution {test_distribution!r}")

    with torch.random.fork_rng():
        torch.manual_seed(0)
        samples = reference.rsample(sample_shape=(100,))

    params = fit_prior(samples.numpy(), distribution=test_distribution)
    print(params)

    for name, value in params.items():
        assert bool(torch.isfinite(value).all()), f"{name} is not finite: {value}"
    if test_distribution == "Gamma":
        assert params["concentration"].item() > 0 and params["rate"].item() > 0


if __name__ == "__main__":
    # Each entry: (banner printed before the run, dataset key for fitgp, optimizer iterations).
    runs = (
        ("Optimizing GP Hyperparameters for Lake Zurich", "lake_zurich", 500),
        (
            "Optimizing GP Hyperparameters for Intel Lab Temperature",
            "intel_lab_data-temperature",
            100,
        ),
        (
            "Optimizing GP Hyperparameters for Intel Lab Humidity",
            "intel_lab_data-humidity",
            100,
        ),
        ("Optimizing GP Hyperparameters for Phosphorus", "phosphorus", 100),
    )
    for banner, dataset_key, iterations in runs:
        print(banner)
        fitgp(dataset_key, ntrain=iterations)

    # with open("config.json", "r") as f:
    #     config = json.load(f)
    #     function_module_names = config["function_module_names"]

    # parser = argparse.ArgumentParser()
    # parser.add_argument(
    #     "module_name",
    #     choices=function_module_names,
    #     help="Function module name (based on module2path in converpath.py)",
    # )
    # parser.add_argument(
    #     "--noise-std",
    #     dest="noise_std",
    #     type=float,
    #     help="Standard deviation of the Gaussian noise",
    # )
    # parser.add_argument(
    #     "--max-data",
    #     dest="max_data",
    #     type=int,
    #     help="Maximum number of data to train",
    # )
    # args = parser.parse_args()

    # function_module_name = args.module_name
    # noise_std = args.noise_std if args.noise_std else 0.01
    # max_data = args.max_data if args.max_data else 100

    # if function_module_name == "branin":
    #     xsize = 100
    #     zsize = 100
    #     ard = True
    #     task_identifiers = [
    #         "original_2_2",
    #     ]

    # elif function_module_name == "goldstein":
    #     xsize = 100
    #     zsize = 100
    #     ard = True
    #     task_identifiers = [
    #         "original_2_4",
    #     ]

    # elif function_module_name.startswith("gas_transmission"):
    #     xsize = 100
    #     zsize = 100
    #     ard = True
    #     task_identifiers = [
    #         "original_2_4",
    #     ]

    # elif function_module_name.startswith("welded_beam_design"):
    #     xsize = 100
    #     zsize = 100
    #     ard = True
    #     task_identifiers = [
    #         "original_2_4",
    #     ]

    # else:
    #     xsize = 100
    #     zsize = 100
    #     ard = True
    #     task_identifiers = [
    #         "original_2_4",
    #     ]

    # for task_identifier in task_identifiers:
    #     fitgp(
    #         function_module_name,
    #         task_identifier,
    #         xsize,
    #         zsize,
    #         noise_std,
    #         ntrain=100,
    #         max_ndata=max_data,
    #         ard=ard,
    #     )
