from sklearn.gaussian_process.kernels import ConstantKernel, Matern
from functions.bbob import *
from functools import partial
from space_gen import get_func, get_bounds
import wandb
import datetime
import numpy as np
from lamcts.utils import standardization, minmax
from sklearn.gaussian_process import GaussianProcessRegressor
from scipy.stats import norm
from torch.quasirandom import SobolEngine

def expected_improvement(gpr, X, fX, lb, ub, xi=0.0, use_ei=True):
    """Score candidate points with Expected Improvement (maximization form).

    Candidates are rescaled with ``minmax(X, min=lb, max=ub)`` before being
    fed to the surrogate, and the incumbent is taken as the maximum of the
    standardized observed values, so ``gpr`` is assumed to have been fitted
    on inputs/targets normalized the same way (as ``train_gpr`` does).

    Args:
        gpr: A fitted ``GaussianProcessRegressor`` surrogate.
        X: Candidate points in the original search space, shape (m, d).
        fX: Objective values observed so far (any shape; flattened to a
            column before standardization).
        lb, ub: Search-space bounds passed to ``minmax``; must not be None.
        xi: Exploration margin subtracted from the improvement.
        use_ei: If False, return the posterior mean instead of EI.

    Returns:
        The posterior mean ``mu`` as returned by ``gpr.predict`` when
        ``use_ei`` is False; otherwise an (m, 1) array of EI values, with
        entries whose predictive std is exactly 0 set to 0.
    """
    assert lb is not None and ub is not None
    X = minmax(X, min=lb, max=ub)

    mu, sigma = gpr.predict(X, return_std=True)
    if not use_ei:
        return mu

    # Incumbent in the same standardized space the surrogate was trained in.
    mu_sample = standardization(fX.reshape((-1, 1)))
    mu_sample_opt = np.max(mu_sample)

    sigma = sigma.reshape(-1, 1)
    imp = (mu - mu_sample_opt - xi).reshape((-1, 1))
    # sigma == 0 yields inf/nan here; those entries are zeroed just below,
    # so the divide/invalid warnings carry no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        Z = imp / sigma
        ei = imp * norm.cdf(Z) + sigma * norm.pdf(Z)
    ei[sigma == 0.0] = 0.0
    return ei

def train_gpr(x_observed, y_observed, lb, ub, dims, noise=0.1):
    """Fit a Matern-5/2 Gaussian-process surrogate to the observations.

    Inputs are rescaled with ``minmax(X, min=lb, max=ub)`` and targets are
    passed through ``standardization`` before fitting, so the returned
    model predicts in that normalized space (``expected_improvement``
    applies the same transforms to its inputs).

    Args:
        x_observed: Observed points; reshaped to ``(-1, dims)``.
        y_observed: Observed objective values; flattened to 1-D.
        lb, ub: Search-space bounds passed to ``minmax``.
        dims: Dimensionality of each point.
        noise: Observation-noise standard deviation. Its square is used as
            sklearn's ``alpha`` (added to the kernel-matrix diagonal).
            Defaults to 0.1, the previously hard-coded value.

    Returns:
        The fitted ``GaussianProcessRegressor``.
    """
    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
    # normalize_y=False because the targets are standardized manually below.
    gpr = GaussianProcessRegressor(kernel=kernel, alpha=noise**2, normalize_y=False)

    X = np.asarray(x_observed).reshape(-1, dims)
    fX = np.asarray(y_observed).reshape(-1)
    X = minmax(X, min=lb, max=ub)
    fX = standardization(fX)

    gpr.fit(X, fX)
    return gpr
    
def _log_sample(counter, value, y_observed, mode):
    """Log one evaluation plus the running best to the active wandb run.

    ``value`` is logged as its absolute value. The running best is the max
    of ``y_observed``, logged as-is when ``mode == "real"`` and as its
    absolute value otherwise.
    """
    best_value = np.max(y_observed)
    curt_best_value = best_value if mode == "real" else np.absolute(best_value)
    wandb.log({
        "sample counter": counter,
        "sample value": np.absolute(value),
        "best value": curt_best_value,
    })


def simpleBO(args, n_init=3):
    """Run ``args.rep`` independent GP-EI Bayesian-optimization runs.

    Each run opens its own wandb run, draws ``n_init`` uniform random
    points, then for ``args.iteration - n_init`` steps fits a GP surrogate,
    scores 10000 uniform random candidates with Expected Improvement and
    evaluates the best-scoring one. The objective returned by ``get_func``
    is negated, so the loop maximizes ``-objective``.

    Args:
        args: Namespace providing ``search_space_id``, ``dataset_id``,
            ``mode``, ``dims``, ``rep``, ``iteration`` and ``similar``
            (plus whatever ``get_bounds`` reads).
        n_init: Number of initial random samples per run (default 3, the
            previously hard-coded value). Must be at least 1 so the first
            surrogate has data.

    Raises:
        ValueError: If ``n_init`` < 1.
    """
    if n_init < 1:
        raise ValueError("n_init must be >= 1")

    def f(x):
        # Accept either a flat point or a single-row (1, d) batch.
        if x.ndim == 2:
            x = x[0]
        assert x.ndim == 1
        y = get_func(args.search_space_id, args.dataset_id, args.mode, args.dims)(x)*-1
        if isinstance(y, float):
            return y
        return y.item()

    lb, ub = get_bounds(args)
    dims = args.dims
    n_candidates = 10000
    # Timestamp in UTC+8 wall-clock time (replaces deprecated utcnow() + 8h).
    utc_plus_8 = datetime.timezone(datetime.timedelta(hours=8))

    for _rep in range(args.rep):
        ts = datetime.datetime.now(utc_plus_8)
        ts_name = f'-ts{ts.month}-{ts.day}-{ts.hour}-{ts.minute}-{ts.second}'
        wandb.init(
            project="bbob100",
            name=f"GP-EI-{args.search_space_id}-{args.dataset_id}-{ts_name}",
            job_type="GP-EI-revise",
            tags=[f"dim={args.dims}", f"similar={args.similar}", f"search_space_id={args.search_space_id}", f"dataset_id={args.dataset_id}"]
        )
        # Always close the run, even if an evaluation or GP fit raises.
        try:
            x_observed = []
            y_observed = []

            # Initial design: uniform random samples.
            for i in range(n_init):
                x = np.random.uniform(lb, ub, size=(1, dims))
                y = f(x)
                x = x[0]
                assert x.ndim == 1
                assert len(x) == dims
                x_observed.append(x)
                y_observed.append(y)
                _log_sample(i + 1, y, y_observed, args.mode)

            for i in range(args.iteration - n_init):
                # Surrogate model on all data gathered so far.
                gpr = train_gpr(x_observed, y_observed, lb, ub, dims)

                # Acquisition optimization by random search over candidates.
                X = np.random.uniform(lb, ub, size=(n_candidates, dims))
                X_ei = expected_improvement(gpr, X, np.asarray(y_observed), lb, ub, xi=0.001, use_ei=True)
                X_ei = X_ei.reshape(len(X))
                cand = X[np.argmax(X_ei)]
                cand_y = f(cand)

                # Augment the dataset.
                assert cand.ndim == 1
                assert len(cand) == dims
                x_observed.append(cand)
                y_observed.append(cand_y)
                print('{}  {}  {}'.format(i, cand_y, np.max(y_observed)))

                _log_sample(i + n_init + 1, cand_y, y_observed, args.mode)
        finally:
            wandb.finish()
    
    
    
