# Copyright (c) 2023
# Copyright holder of the paper "End-to-End Meta-Bayesian Optimisation with Transformer Neural Processes".
# Submitted to NeurIPS 2023 for review.
# All rights reserved.

import os
import multiprocessing as mp

import gym
import numpy as np
import torch

from nap.environment.mip import get_cond_mip_specs

def load_mean_gp_hyperparameters(gp_paths):
    """Load the saved GP models and average their hyperparameters.

    Args:
        gp_paths: iterable of paths to GP models saved with ``torch.save``.

    Returns:
        dict with the keys ``cont_kern_ls`` and ``cat_kern_ls`` (lists, averaged
        over the first axis of the concatenated lengthscales) and ``ls``,
        ``lamda``, ``lik_noise`` and ``mean_const`` (Python scalars, averaged
        over all models).

    Raises:
        ValueError: if ``gp_paths`` is empty (raised by ``np.concatenate``).
    """
    collected = {
        "cont_kern_ls": [],
        "cat_kern_ls": [],
        "ls": [],
        "lamda": [],
        "mean_const": [],
        "lik_noise": [],
    }
    for gp_path in gp_paths:
        # NOTE(review): torch.load unpickles the whole model object, so only
        # point this at checkpoints produced by this project's own training.
        gp_model = torch.load(gp_path)
        base_kernel = gp_model.covar_module.base_kernel
        collected["cont_kern_ls"].append(base_kernel.continuous_kern.lengthscale.detach().cpu().numpy())
        collected["cat_kern_ls"].append(base_kernel.categorical_kern.lengthscale.detach().cpu().numpy())
        collected["ls"].append(base_kernel.lengthscale.detach().cpu().numpy())
        collected["lamda"].append(base_kernel.lamda)
        collected["lik_noise"].append(gp_model.likelihood.noise.detach().cpu().numpy())
        collected["mean_const"].append(gp_model.mean_module.constant.detach().cpu().numpy())

    return {
        "cont_kern_ls": np.concatenate(collected["cont_kern_ls"]).mean(0).tolist(),
        "cat_kern_ls": np.concatenate(collected["cat_kern_ls"]).mean(0).tolist(),
        "ls": np.concatenate(collected["ls"]).mean().item(),
        "lamda": np.array(collected["lamda"]).mean().item(),
        "lik_noise": np.concatenate(collected["lik_noise"]).mean().item(),
        "mean_const": np.array(collected["mean_const"]).mean().item(),
    }


def select_best_checkpoint(regret):
    """Return the checkpoint key with the lowest mean regret.

    Args:
        regret: dict mapping checkpoint -> ``(mean, std)``.

    Returns:
        The key whose mean is smallest; on ties the first inserted key wins
        (``np.argmin`` returns the first minimum).

    Raises:
        ValueError: if ``regret`` is empty.
    """
    ckpts = list(regret)
    if not ckpts:
        raise ValueError("no checkpoints were evaluated, cannot select the best one")
    return ckpts[int(np.argmin([regret[k][0] for k in ckpts]))]


def format_validation_report(reward, regret):
    """Render the per-checkpoint statistics and the best-checkpoint summary.

    Args:
        reward: dict mapping checkpoint -> ``(mean, std)`` of the episode reward.
        regret: dict mapping checkpoint -> ``(mean, std)`` of the regret.

    Returns:
        Tuple ``(stats_text, best_text)`` of newline-joined strings without a
        trailing newline, so the same text can be printed and written to file.

    Raises:
        ValueError: if ``regret`` is empty.
    """
    lines = ["========== REWARD =========="]
    lines.extend(f'ckpt={k} mean={reward[k][0]:.5f} std={reward[k][1]:.5f}' for k in reward)
    lines.append("========== REGRET ==========")
    lines.extend(f'ckpt={k} mean={regret[k][0]:.5f} std={regret[k][1]:.5f}' for k in regret)

    best_ckpt = select_best_checkpoint(regret)
    best_text = (f"Best regret ckpt. is {best_ckpt}\n"
                 f"Best Regret mean={regret[best_ckpt][0]:.5f} "
                 f"std={regret[best_ckpt][1]:.5f}")
    return "\n".join(lines), best_text


if __name__ == '__main__':
    mp.set_start_method('spawn')

    from datetime import datetime
    from nap.RL.ppo import PPO
    from nap.policies.policies import iclr2020_NeuralAF
    from gym.envs.registration import register

    rootdir = os.path.dirname(os.path.realpath(__file__))
    rootdir_nap = os.path.join(rootdir, "nap")

    dims, num_dims, cat_dims, min_points, num_classes, train_datasets, valid_datasets, test_datasets, train_gps = \
        get_cond_mip_specs(rootdir)

    # GP hyperparameters averaged over all training GPs; passed to the environment below
    gp_hypers = load_mean_gp_hyperparameters(train_gps)

    env_id = "MetaBO-T295-fixed-v0"
    # all validation logs and the final summary file live below this folder
    validate_root = os.path.join(rootdir_nap, "log/VALIDATE/", "MIP", env_id)

    # specify PPO parameters
    # 1 iteration will run n_seeds seeds so e.g. running 5 iterations with n_seeds=10 will run 50 seeds per test dataset
    n_iterations = 5
    n_workers = 5
    n_seeds = 10  # number of seeds per test task
    horizon = 295
    batch_size = len(valid_datasets) * horizon * n_seeds
    arch_spec = 4 * [200]

    reward = {}
    regret = {}
    train_steps = 500
    for ckpt in range(100, train_steps + 100, 100):
        # the range ends at train_steps itself; that last value is evaluated as train_steps - 1
        # (presumably because checkpoints are stored under zero-based iteration indices - TODO confirm)
        if ckpt == train_steps:
            ckpt -= 1

        # specify environment (rebuilt for every checkpoint so each registration gets a fresh dict)
        env_spec = {
            "env_id": env_id,
            "f_type": "MIP",
            "D": dims,
            "f_opts": {
                "kernel": "Mixture",
                "min_regret": 1e-20,
                "data": valid_datasets,
                "cat_dims": cat_dims,
                "num_classes": num_classes,
                "cont_dims": num_dims,
                "shuffle_and_cutoff": False,
                "continuous_kern_lengthscale": gp_hypers["cont_kern_ls"],
                "categorical_kern_lengthscale": gp_hypers["cat_kern_ls"],
                "outputscale": gp_hypers["ls"],
                "lamda": gp_hypers["lamda"],
                "likelihood_noise": gp_hypers["lik_noise"],
                "mean_constant": gp_hypers["mean_const"],
            },
            "features": ["posterior_mean", "posterior_std", "incumbent", "timestep_perc", "timestep", "budget"],
            "T": horizon,
            "n_init_samples": 5,
            "pass_X_to_pi": False,
            "local_af_opt": False,
            "cardinality_domain": 200,
            "remove_seen_points": False,  # only True for testing
            # will be set individually for each new function to the sampled hyperparameters
            "kernel": "Mixture",
            "kernel_lengthscale": None,  # kernel_lengthscale,
            "kernel_variance": None,  # kernel_variance,
            "noise_variance": None,  # noise_variance,
            "use_prior_mean_function": False,
            "reward_transformation": "neg_log10"  # true maximum not known
        }

        ppo_spec = {
            "batch_size": batch_size,
            "max_steps": n_iterations * batch_size,
            "minibatch_size": batch_size // 20,
            "n_epochs": 4,
            "lr": 1e-4,
            "epsilon": 0.15,
            "value_coeff": 1.0,
            "ent_coeff": 0.0,
            "gamma": 0.98,
            "lambda": 0.98,
            "loss_type": "GAElam",
            "normalize_advs": True,
            "n_workers": n_workers,
            "env_id": env_spec["env_id"],
            "seed": 0,
            "argmax": True,
            "env_seeds": list(range(n_workers)),
            "policy_options": {
                "activations": "relu",
                "arch_spec": arch_spec,
                "exclude_t_from_policy": True,
                "exclude_T_from_policy": True,
                "use_value_network": True,
                "t_idx": -2,
                "T_idx": -1,
                "arch_spec_value": arch_spec
            }
        }

        # after Model pretraining: load the weights of the checkpoint under evaluation.
        # The path is anchored at rootdir_nap (like the validation logs) so the script no longer
        # depends on the current working directory; this is the same location as the former
        # relative "nap/log/TRAIN/..." path when run from the script's folder.
        ppo_spec.update({
            "load": True,
            "load_path": os.path.join(rootdir_nap, "log/TRAIN/MIP/MetaBO-fixed-v0/2023-04-15-10-58-12/"),
            "param_iter": str(ckpt),
        })

        # register environment
        register(
            id=env_spec["env_id"],
            entry_point="nap.environment.function_gym:MetaBOEnv",
            max_episode_steps=env_spec["T"],
            reward_threshold=None,
            kwargs=env_spec
        )

        # log data and weights go here, use this folder for evaluation afterwards
        logpath = os.path.join(validate_root, f"{ckpt}_ckpt", datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M-%S"))

        # set up policy; with "argmax" enabled the policy is always built as deterministic
        policy_fn = lambda observation_space, action_space, deterministic: iclr2020_NeuralAF(
            observation_space=observation_space,
            action_space=action_space,
            deterministic=True if ppo_spec["argmax"] else deterministic,
            options=ppo_spec["policy_options"])

        # do testing
        print("Testing on {}.\nFind logs, weights, and learning curve at {}\n\n".format(env_spec["env_id"], logpath))

        ppo = PPO(policy_fn=policy_fn, params=ppo_spec, logpath=logpath, save_interval=100)
        ppo.test()

        ep_rewards = np.array(ppo.teststats['avg_ep_reward'])
        ep_regrets = np.array(ppo.teststats['regret'])
        reward[ckpt] = ep_rewards.mean(), ep_rewards.std()
        regret[ckpt] = ep_regrets.mean(), ep_regrets.std()

        # unregister the environment so the same id can be registered again for the next checkpoint.
        # Older gym releases keep the specs in `registry.env_specs`; newer ones use a plain dict registry.
        gym_registry = gym.envs.registration.registry
        del getattr(gym_registry, "env_specs", gym_registry)[env_spec["env_id"]]

    print('======================================================')
    print('======================= DONE =========================')
    print('======================================================')
    stats_text, best_text = format_validation_report(reward, regret)
    print(stats_text)
    print(best_text)

    # make sure the target folder exists even if the test runs did not create it
    os.makedirs(validate_root, exist_ok=True)
    statsfile = os.path.join(validate_root, 'validation.txt')
    with open(statsfile, 'w') as f:
        # the file separates the statistics from the summary with one blank line
        f.write(stats_text + "\n\n" + best_text)