import GPy
from bayesian_optimization_bts_red import BTS_RED
import pickle
import numpy as np

# Number of BO iterations per run (passed to BTS_RED both as T and as n_iter).
max_iter = 70

# ls / ls_noise_var are only used below to label the result file names; they
# are not passed to BTS_RED in this script.
# NOTE(review): presumably they must match the settings used to generate
# obj_funcs/synth_func.pkl — confirm.
ls, ls_noise_var = 0.04, 0.15
# noise_var_max feeds the R2 computation; noise_var_min only appears in the
# result file names.
noise_var_min, noise_var_max = 1e-4, 0.2

# Load the tabulated synthetic objective: its grid ("domain"), objective values
# ("f") and per-point noise variances ("f_noise_var").
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
log_file_name = "obj_funcs/synth_func.pkl"
with open(log_file_name, "rb") as func_file:  # closes the handle once loaded
    all_func_info = pickle.load(func_file)
domain = all_func_info["domain"]
f = all_func_info["f"]
f_noise_var = all_func_info["f_noise_var"]

def synth_func(param, n_t):
    """Simulate ``n_t`` noisy evaluations of the tabulated objective at ``param[0]``.

    The query value is snapped to the closest entry of the module-level
    ``domain`` grid (assumes ``domain`` is a 1-D numpy array — TODO confirm).
    At that grid point, ``n_t`` samples are drawn from a normal distribution
    with mean ``f[idx]`` and standard deviation ``sqrt(f_noise_var[idx])``.

    Returns:
        ``(sample_mean, true_value, noise_variance)`` at the snapped grid point.
    """
    nearest_idx = np.argmin(np.abs(domain - param[0]))
    true_value = f[nearest_idx]
    noise_variance = f_noise_var[nearest_idx]
    # Average the replicated observations into one noisy measurement.
    samples = np.random.normal(true_value, np.sqrt(noise_variance), n_t)
    return np.mean(samples), true_value, noise_variance

# Number of inputs queried per BO iteration (passed to BTS_RED as batch_size).
batch_size = 50

# Whether to use one fixed number of replications n_t for every query.
# When False, n_t is presumably chosen by BTS_RED within [n_min, n_max] — confirm.
# fix_nt_value is only used when fix_nt_flag is True; values tried: 1, 5, 10, 20.
fix_nt_flag = False
fix_nt_value = 20

# R2 is passed to BTS_RED.
# - Adaptive n_t: noise_var_max is scaled by (sqrt(batch_size) + 1) / (batch_size - 1)
#   and a factor of 0.3 (0.2 was also tried).
# - Fixed n_t: R2 is noise_var_max / fix_nt_value.
ratio = (np.sqrt(batch_size) + 1) / (batch_size - 1)
R2 = noise_var_max / fix_nt_value if fix_nt_flag else noise_var_max * ratio * 0.3

# beta_t is passed to BTS_RED as an all-ones array of length 5000.
# Presumably indexed by iteration (5000 is well above max_iter) — TODO confirm.
beta_t = np.ones(5000)

# Passed to BTS_RED as n_min / n_max (presumably bounds on n_t — confirm).
n_min = 2
n_max = 50

# Initialization: init_size initial inputs, each queried with the same fixed
# n_t of fix_nt_init.
init_size = 10
fix_nt_init = 10

# Number of random features used to approximately run Thompson sampling.
M_TS = 50

# The GP hyperparameters are optimized after every gp_opt_schedule iterations.
gp_opt_schedule = 5

# Indices of the repeated runs; each index selects its own init file and
# results file in the loop below.
run_list = np.arange(50)

for itr in run_list:
    # One results file per run. The adaptive and fixed-n_t variants share the
    # same name layout; the fixed-n_t variant additionally records n_t just
    # before the "_init_" suffix. Building the name once keeps both variants
    # in sync (explicit str() keeps names identical to earlier results).
    nt_tag = "_nt_" + str(fix_nt_value) if fix_nt_flag else ""
    log_file_name = (
        "results_bts_red_known/res_ls_" + str(ls) + "_ls_noise_var_" + str(ls_noise_var)
        + "_noise_range_" + str(noise_var_min) + "_" + str(noise_var_max)
        + "_iter_" + str(itr) + "_batch_size_" + str(batch_size) + "_R2_" + str(R2)
        + nt_tag + "_init_" + str(init_size) + ".pkl"
    )

    # Per-run initial-inputs file, passed as use_init. save_init=False, so it is
    # presumably read rather than written — confirm against BTS_RED.
    init_file_name = "inits/init_itr_" + str(itr) + "_init_" + str(init_size) + ".p"

    bo_ts = BTS_RED(
        f=synth_func,
        pbounds={'x1': (0, 1)},
        gp_opt_schedule=gp_opt_schedule,
        log_file=log_file_name,
        M_TS=M_TS,
        n_min=n_min,
        n_max=n_max,
        noise_var_func=f_noise_var,
        domain=domain,
        batch_size=batch_size,
        R2=R2,
        beta_t=beta_t,
        fix_nt_flag=fix_nt_flag,
        fix_nt_value=fix_nt_value,
        use_init=init_file_name,
        save_init=False,
        save_init_file=None,
        T=max_iter,
        fix_nt_init=fix_nt_init,
    )
    bo_ts.maximize(n_iter=max_iter, init_points=init_size)
