import os
import pickle
import numpy as np
import torch

from gp import GP

from .basefunc import BaseFunc


class Branin(BaseFunc):
    """Negated, rescaled Branin-Hoo benchmark in two dimensions.

    Inputs are mapped affinely with ``x1 = 15 * u1 - 5`` and ``x2 = 15 * u2``.
    This matches the standard Branin box ``[-5, 10] x [0, 15]`` when ``u`` lies
    in the unit square (presumably what ``BaseFunc.generate_discrete_points``
    produces -- TODO confirm). The rescaled Branin expression is then evaluated
    and multiplied by ``-1``.
    """

    def __init__(
        self,
        xsize=10,
        noise_std=0.01,
    ):
        """Build the benchmark and its discrete 2-D input domain.

        Parameters
        ----------
        xsize : int
            the size of the discrete input domain; passed, together with
            ``xdim=2``, to ``BaseFunc.generate_discrete_points``
        noise_std : float
            forwarded to ``BaseFunc.__init__`` as ``noise_std``; presumably
            the standard deviation of the noise used to generate noisy
            observations

        """
        xsize = int(xsize)
        noise_std = float(noise_std)

        xdim = 2

        super().__init__(xdim, xsize, noise_std=noise_std)

        self.module_name = "branin"
        self.xsize = xsize
        self.x_domain = BaseFunc.generate_discrete_points(xsize, xdim)

    def get_noiseless_observation_from_inputs(self, x):
        """Evaluate the negated, rescaled Branin function at inputs ``x``.

        Parameters
        ----------
        x : tensor array of size (n, self.xdim)
            inputs to be evaluated. Anything accepted by ``torch.as_tensor``
            works; the values are reshaped to ``(-1, self.xdim)``.

        Returns
        -------
        val : tensor array of shape (n,)
            function values at ``x``, on the same device as ``x``. For
            floating-point inputs the dtype follows ``x``; integer inputs
            are promoted to torch's default float dtype.

        """
        with torch.no_grad():
            x = torch.as_tensor(x).reshape(-1, self.xdim)
            # Affine map u -> (15 * u1 - 5, 15 * u2). Using Python scalars
            # (instead of a freshly allocated offset tensor, which would live
            # on the CPU) keeps the computation on x's device.
            x1 = 15.0 * x[:, 0] - 5.0
            x2 = 15.0 * x[:, 1]

            tmp = (
                x2
                - 5.1 * x1**2 / (4 * torch.pi**2)
                + 5.0 * x1 / torch.pi
                - 6.0
            )
            # Rescaled Branin: (tmp^2 + (10 - 10/(8*pi)) cos(x1) - 44.81) / 51.95,
            # multiplied by -1.
            val = (
                -1.0
                / 51.95
                * (
                    tmp * tmp
                    + (10.0 - 10.0 / (8.0 * torch.pi)) * torch.cos(x1)
                    - 44.81
                )
            )
            val = val.reshape(-1)
        return val

    def get_noiseless_observation_from_input_idxs(self, x_idxs):
        """Evaluate the function at the domain points selected by ``x_idxs``.

        Parameters
        ----------
        x_idxs : tensor array or list of int64 of shape (n,)
            row indices into ``self.x_domain`` of the inputs to be evaluated

        Returns
        -------
        val : tensor array of shape (n,)
            values of ``get_noiseless_observation_from_inputs`` at the
            selected rows of ``self.x_domain``

        """
        x = self.x_domain[x_idxs, :].reshape(-1, self.xdim)
        return self.get_noiseless_observation_from_inputs(x)
