"""
Learn multiple interesting sets (isets) of a black-box function at once (e.g., contours at different quantiles).
"""

import os
import sys
import numpy as np
import torch
import argparse
import pickle

from configurations import BLACKBOX_FUNCTIONS, ACQUISITION_FUNCTIONS, MODES
import util
from gp import GP


# Command-line interface. Options without a default that the script cannot run
# without are marked required so a missing value fails fast in argparse instead
# of crashing later with an opaque TypeError.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--num-rand", dest="nrand", type=int, required=True, help="Number of random runs"
)
parser.add_argument(
    "--num-iter",
    dest="niter",
    type=int,
    required=True,
    help="Number of learning iterations",
)
parser.add_argument(
    "--num-init-observation",
    dest="n_init_obs",
    type=int,
    help="Number of initial observations",
)
parser.add_argument(
    "--initseed",
    dest="initseed",
    type=int,
    required=True,
    help="random seed to generate initial observations",
)
parser.add_argument(
    "--bbfunc",
    type=str,
    help="Blackbox function name and optionally key,value of function parameters such as input dimension, e.g., branin___key1__val1___key2__val2",
)
parser.add_argument(
    "--acqfunc",
    choices=ACQUISITION_FUNCTIONS,
    default=ACQUISITION_FUNCTIONS[0],
    help="acquisition_function_name",
)
# One or more interesting-set types; all of them are learned simultaneously.
parser.add_argument(
    "--iset",
    "--names-list",
    nargs="+",
    default=["maximizer"],
    help="type of the interesting set",
)
parser.add_argument(
    "--mode",
    choices=MODES,
    default=MODES[0],
    help="boundary, cl_ord, cl_val",
)
parser.add_argument(
    "--plot",
    type=int,
    choices=[0, 1],
    default=0,
    help="1 if plot iteration-by-iteration GP posterior, 0 otherwise",
)
parser.add_argument(
    "--plotfilename", type=str, help="filename to save the iteration-by-iteration plot"
)

args = parser.parse_args()

bbfunc = args.bbfunc
initseed = args.initseed
acqfunc = args.acqfunc

# All interesting-set types requested on the command line.
iset_types = args.iset

nrand = args.nrand
niter = args.niter
n_init_obs = args.n_init_obs
mode = args.mode

is_show = args.plot
plot_filename = args.plotfilename
if plot_filename:
    plot_filename = f"img/{plot_filename}_{bbfunc}"
    # Create the plot directory up front so saving figures cannot fail on a
    # missing folder mid-run.
    os.makedirs(os.path.dirname(plot_filename), exist_ok=True)

bbfunc_name, bbfunc_params = util.get_func_name_and_params(bbfunc)
print(f"Blackbox function: {bbfunc_name}, with params: {bbfunc_params}")

bbfunc_obj = util.get_function_instance(bbfunc_name, bbfunc_params)
x_domain = bbfunc_obj.x_domain
xsize = bbfunc_obj.xsize

# Noiseless evaluations over the whole domain; used for visualization only.
gt_func_evals = bbfunc_obj.get_noiseless_observation_from_input_idxs(
    list(range(xsize))
)

# Ground-truth interesting sets, keyed by iset type.
gt_iset_idxs_collection = {}
gt_iset_collection = {}
gt_iset_evals_collection = {}

for iset_type in iset_types:
    gt_iset_idxs_collection[iset_type] = bbfunc_obj.iset_idxs(iset_type)
    gt_iset_collection[iset_type] = bbfunc_obj.iset(iset_type)
    gt_iset_evals_collection[iset_type] = bbfunc_obj.iset_evals(iset_type)
    print(
        f"Ground truth {iset_type}: {gt_iset_collection[iset_type]}, evaluation: {gt_iset_evals_collection[iset_type]}."
    )

gt_gp_hyperparameters = bbfunc_obj.gp_hyperparameters

# Both output files share one stem that encodes the full experiment configuration.
_regret_file_stem = f"out/{bbfunc}_REGRET_{'-'.join(iset_types)}_{acqfunc}_nr{nrand}_it{niter}_seed{initseed}_init{n_init_obs}_{mode}"
regret_file_to_save = f"{_regret_file_stem}.txt"
all_regret_file_to_save = f"{_regret_file_stem}.pkl"

# Regret matrices of shape (nrand, niter), pre-filled with the placeholder 100.0:
# one per iset type, plus the max over isets.
all_regrets_collection = {
    iset_type: np.full([nrand, niter], 100.0) for iset_type in iset_types
}
all_regrets = np.full([nrand, niter], 100.0)

for tr in range(nrand):
    seed = initseed + tr

    # Seed NumPy and PyTorch so each random run is reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Initial observations
    init_X, init_y = bbfunc_obj.get_init_observations(n_init_obs, seed=seed)

    # Initialize the GP model from the ground-truth hyperparameters.
    gp_model = GP(
        init_X, init_y, initialization=gt_gp_hyperparameters, prior=None, ard=True
    )

    # All queried inputs / observations so far; one query is appended per iteration.
    obs_X = init_X.detach().clone()
    obs_y = init_y.detach().clone()

    for t in range(niter):
        print(
            f"\n***********************************\n* Random run {tr:2d} *** Iteration {t:3d} *\n-----------------------------------"
        )

        gp_model.set_train_data(obs_X, obs_y, strict=False)

        # Multiplier for the confidence band [mean - beta*std, mean + beta*std].
        # NOTE(review): resembles the finite-domain GP-UCB schedule
        # 2*log(|D| t^2 pi^2 / (6 delta)) with delta = 0.1, but scaled by 1/10 and
        # used directly rather than as sqrt(beta) -- confirm this is intended.
        beta = 2.0 * np.log(xsize * (t + 1) ** 2 * np.pi**2 / 6 / 0.1) / 10
        print(f"beta = {beta}")

        f_preds = GP.predict_f(gp_model, x_domain)
        f_means = f_preds.mean
        f_vars = f_preds.variance
        f_stds = torch.sqrt(f_vars)

        upper_f = (f_means + beta * f_stds).reshape(xsize)
        lower_f = (f_means - beta * f_stds).reshape(xsize)

        # Estimate every interesting set from the posterior mean.
        iset_idxs_collection = {}
        for iset_type in iset_types:
            iset_idxs = util.get_iset_idxs(f_means, iset_type)
            iset_idxs_collection[iset_type] = iset_idxs

            print(f"({iset_type}) iset_idxs = {iset_idxs}")
            print(f"({iset_type}) iset = {x_domain[iset_idxs]}")
            print(
                f"({iset_type}) iset eval = {bbfunc_obj.get_noiseless_observation_from_input_idxs(iset_idxs)}"
            )
            print(
                f"ground truth {iset_type}: {gt_iset_collection[iset_type]}, eval: {gt_iset_evals_collection[iset_type]}"
            )

        # Regret of the current estimates: per iset type, and the max over all isets.
        if mode == "boundary":
            all_regrets[tr, t] = -1e9
            for iset_type in iset_types:
                gt_set = bbfunc_obj.get_noiseless_observation_from_input_idxs(
                    iset_idxs_collection[iset_type]
                )
                lt_set = bbfunc_obj.get_noiseless_observation_from_input_idxs(
                    util.get_complement_set(xsize, iset_idxs_collection[iset_type])
                )
                prediction_regret = util.get_regret_between_2_sets(
                    gt_set, lt_set, mode="sum"
                )
                all_regrets_collection[iset_type][tr, t] = prediction_regret
                all_regrets[tr, t] = max(all_regrets[tr, t], prediction_regret)
                print(f"{iset_type}: Prediction regret = {prediction_regret}")

        else:
            raise Exception(f"Haven't implemented the regret for mode: {mode}")

        # Across all isets, select the (imax, rmax) pair of the most uncertain iset,
        # i.e. the one with the highest regret upper bound.
        max_regret_upper_bound = float("-inf")
        selected_imax_idx = None
        selected_rmax_idx = None

        for iset_type in iset_types:
            # Use THIS iset's estimate, not whichever iset was estimated last above.
            cur_iset_idxs = iset_idxs_collection[iset_type]

            if mode == "cl_val":
                imax_idx, rmax_idx = util.get_pair_in_cl_val_mode(
                    lower_f, upper_f, cur_iset_idxs
                )

            elif mode == "boundary":
                imax_idx, rmax_idx = util.get_pair_in_boundary_mode(
                    lower_f, upper_f, cur_iset_idxs
                )

            elif mode == "cl_ord":
                imax_idx, rmax_idx = util.get_pair_in_cl_ord_mode(
                    lower_f, upper_f, cur_iset_idxs
                )

            else:
                raise Exception(f"Unknown mode: {mode}. Only accept modes in {MODES}!")

            # Optimistic gap in the direction opposite to the current posterior-mean
            # ordering of the pair: upper bound of one minus lower bound of the other.
            regret_upper_bound = upper_f[imax_idx] - lower_f[rmax_idx]
            if f_means[imax_idx] >= f_means[rmax_idx]:
                regret_upper_bound = upper_f[rmax_idx] - lower_f[imax_idx]

            # Always take the first candidate so a selection exists even when the
            # bound is non-finite; ties keep the earliest iset.
            if selected_imax_idx is None or max_regret_upper_bound < regret_upper_bound:
                max_regret_upper_bound = regret_upper_bound
                selected_imax_idx = imax_idx
                selected_rmax_idx = rmax_idx

        # Member of the selected pair with the larger upper bound / smaller lower bound.
        max_upper_idx = selected_rmax_idx
        if upper_f[selected_imax_idx] > upper_f[selected_rmax_idx]:
            max_upper_idx = selected_imax_idx

        min_lower_idx = selected_rmax_idx
        if lower_f[selected_imax_idx] < lower_f[selected_rmax_idx]:
            min_lower_idx = selected_imax_idx

        # The selected pair is (selected_imax_idx, selected_rmax_idx), where imax is
        # in the estimated iset and rmax is in its complement (per the util helpers).
        if acqfunc == "max_upper":
            query_idx = max_upper_idx

        elif acqfunc == "min_lower":
            query_idx = min_lower_idx

        elif acqfunc == "tightest":
            # Query whichever candidate has the narrower confidence band.
            query_idx = max_upper_idx
            if (
                upper_f[max_upper_idx] - lower_f[max_upper_idx]
                > upper_f[min_lower_idx] - lower_f[min_lower_idx]
            ):
                query_idx = min_lower_idx

        elif acqfunc == "uncertainty":
            # Query whichever member of the pair has the wider confidence band.
            query_idx = selected_imax_idx
            if (
                upper_f[selected_rmax_idx] - lower_f[selected_rmax_idx]
                > upper_f[selected_imax_idx] - lower_f[selected_imax_idx]
            ):
                query_idx = selected_rmax_idx

        elif acqfunc == "exprdesign":
            query_idx = torch.argmax(upper_f - lower_f).item()

        elif acqfunc == "rand":
            query_idx = np.random.randint(0, xsize)

        elif acqfunc == "bo_ei":
            # Expected improvement over the best observation so far.
            max_obs_y = torch.max(obs_y)
            expected_improvement = (f_means - max_obs_y) * (
                torch.tensor(1.0)
                - torch.distributions.normal.Normal(f_means, f_stds).cdf(max_obs_y)
            ) + f_vars * torch.exp(
                torch.distributions.normal.Normal(f_means, f_stds).log_prob(max_obs_y)
            )
            query_idx = torch.argmax(expected_improvement).item()

        elif acqfunc == "bo_pi":
            # Probability of improving on the best observation so far.
            max_obs_y = torch.max(obs_y)
            improved_prob = torch.tensor(1.0) - torch.distributions.normal.Normal(
                f_means, f_stds
            ).cdf(max_obs_y)
            query_idx = torch.argmax(improved_prob).item()

        elif acqfunc == "bo_mes":
            # Max-value entropy search using one posterior sample of the maximum y*.
            f_sample = f_preds.sample()
            sample_max = torch.max(f_sample)

            # With gamma = (y* - mu) / sigma and phi/Phi the STANDARD normal pdf/cdf:
            #   MES = gamma * phi(gamma) / (2 * Phi(gamma)) - log Phi(gamma)
            gamma = (sample_max - f_means) / f_stds
            std_normal = torch.distributions.normal.Normal(
                torch.zeros_like(gamma), torch.ones_like(gamma)
            )
            mes_pdf = torch.exp(std_normal.log_prob(gamma))
            mes_cdf = std_normal.cdf(gamma)
            mes = gamma * mes_pdf / 2.0 / mes_cdf - torch.log(mes_cdf)
            query_idx = torch.argmax(mes).item()

        else:
            raise Exception(f"Unknown acquisition function {acqfunc}")

        query_x = x_domain[query_idx, :].reshape(1, -1)
        query_y = bbfunc_obj.get_noisy_observation_from_input_idxs([query_idx]).reshape(
            1,
        )

        if bbfunc_obj.xdim == 1 and (is_show or plot_filename is not None):
            # Only the first iset type is plotted. When no plot filename is given
            # (show-only), nothing is saved.
            # NOTE(review): assumes util.plot_1d treats save_filename=None as
            # "do not save" -- confirm.
            save_filename = None
            if plot_filename is not None:
                save_filename = f"{plot_filename}_{tr}_{t:03d}.png"

            show_pair = acqfunc not in ["rand", "exprdesign"]
            util.plot_1d(
                x_domain,
                gt_func_evals,
                f_means,
                upper_f,
                lower_f,
                iset_idxs_collection[iset_types[0]],
                gt_iset_idxs_collection[iset_types[0]],
                selected_imax_idx if show_pair else None,
                selected_rmax_idx if show_pair else None,
                query_idx,
                query_y.squeeze(),
                obs_X.numpy(),
                obs_y.numpy(),
                iteration=t,
                is_show=is_show,
                save_filename=save_filename,
                is_show_imax_rmax=show_pair,
            )

        print(f"Query idx: {query_idx}, x: {query_x}, y: {query_y}")

        obs_X = torch.cat([obs_X, query_x], dim=0)
        obs_y = torch.cat([obs_y, query_y], dim=0)


# Persist results: the (nrand, niter) max-over-isets regret matrix as CSV text,
# and the per-iset regret matrices (dict keyed by iset type) as a pickle.
# Create the output directories first so a long run cannot fail on a missing folder.
for _out_path in (regret_file_to_save, all_regret_file_to_save):
    _out_dir = os.path.dirname(_out_path)
    if _out_dir:
        os.makedirs(_out_dir, exist_ok=True)

np.savetxt(regret_file_to_save, all_regrets, delimiter=",")
with open(all_regret_file_to_save, "wb") as file:
    pickle.dump(all_regrets_collection, file, protocol=pickle.HIGHEST_PROTOCOL)
