import os
import pickle
import numpy as np
import torch

from func.basefunc import BaseFunc
from gp import GP


class GPSample(BaseFunc):
    """Objective function given by one sample path of a GP on a discrete domain.

    The sample is evaluated at every point of ``self.x_domain`` and cached to a
    pickle file keyed by ``(xdim, xsize, seed)``. Later constructions with the
    same arguments reload the cached sample instead of drawing a new one.
    """

    def __init__(
        self,
        xdim=1,
        xsize=10,
        noise_std=0.01,
        seed=0,
    ):
        """initialize the GPSample object

        Note
        ----
        Only applied to discrete input domain

        Parameters
        ----------
        xdim : int
            the number of input dimensions (one ARD lengthscale per dimension)
        xsize : int
            the size of the discrete input domain
        noise_std : float
            the standard deviation of the noise of the GP model
            to generate the noisy observation
        seed : int
            the random seed to generate the GP sample

        """
        xdim = int(xdim)
        xsize = int(xsize)
        noise_std = float(noise_std)
        seed = int(seed)

        super().__init__(xdim, xsize, noise_std=noise_std)

        # Backing field for the ``gp_hyperparameters`` property defined below,
        # which overrides the getter inherited from BaseFunc; it is filled in
        # by generate_gp_sample.
        self._gp_hyperparameters = None

        self.module_name = "gp_sample"
        self.xsize = xsize
        self.x_domain = BaseFunc.generate_discrete_points(xsize, xdim)
        self.generate_gp_sample(n_rand_obs=3, seed=seed, force_generate=False)

    @property
    def gp_hyperparameters(self):
        """gp_hyperparameters.

        Note
        ----
        This is used when we want to do experiment assuming knowing GP hyperparameters

        Returns
        -------
        dictionary
            GP hyperparameters of the model used to draw the sample, e.g.,
            {
                "likelihood.noise_covar.noise": 0.01,
                "covar_module.base_kernel.lengthscale": <numpy array of 0.1s, one per input dim>,
                "covar_module.outputscale": 1.0,
                "mean_module.constant": 0.0,
            }
            The lengthscale entry is stored as a numpy array
            (``lengthscale.detach().numpy()``), not a list.
        """
        return self._gp_hyperparameters

    def generate_gp_sample(self, n_rand_obs=3, seed=0, force_generate=False):
        """Generate (or load) a GP sample and store it on the object.

        Sets ``self._gp_hyperparameters`` (exposed via ``gp_hyperparameters``)
        and ``self.evaluation``. The latter is a 1-D float32 tensor holding the
        sample's value at each point of ``self.x_domain``.

        Note
        ----
        Given a seed, the generated GP sample
            is stored locally at f"func/gp_sample_data/gp_sample_xdim_{self.xdim}_xsize_{self.xsize}_seed_{seed}.pkl"
            (a path relative to the current working directory).
        The GP sample is only re-generated if the above file does not exist
            or the argument force_generate=True

        Parameters
        ----------
        n_rand_obs : int
            the number of random initial observations
            to generate the GP sample
        seed : int
            random seed to generate the sample
        force_generate : bool
            if force_generate is True: always generate the GP sample
                (overwriting any cached file)
            else: only generate the GP sample if the cache file above
                does not exist.

        Raises
        ------
        ValueError
            if a cached file holds a number of evaluations different from
            ``self.xsize``.
        """
        filename = self._gp_sample_filename(seed)

        if force_generate or not os.path.isfile(filename):
            self._generate_and_save_gp_sample(filename, n_rand_obs, seed)
        else:
            self._load_gp_sample(filename)

    def _gp_sample_filename(self, seed):
        """Return the cache path of the GP sample for this xdim/xsize and ``seed``."""
        return f"func/gp_sample_data/gp_sample_xdim_{self.xdim}_xsize_{self.xsize}_seed_{seed}.pkl"

    def _load_gp_sample(self, filename):
        """Load a cached GP sample from ``filename`` into this object."""
        # NOTE: pickle.load can execute arbitrary code. Only load cache files
        # produced by this class, never files from an untrusted source.
        with open(filename, "rb") as file:
            data = pickle.load(file)

        # reshape(-1) keeps the evaluations 1-D. Caches written for xsize=1 by
        # the old squeeze()-based code hold a 0-d array.
        evaluation = torch.tensor(data["evaluation"], dtype=torch.float32).reshape(-1)
        if evaluation.numel() != self.xsize:
            raise ValueError(
                f"Loaded evaluation from {filename} must have {self.xsize} elements"
            )

        self._gp_hyperparameters = data["hyperparameters"]
        self.evaluation = evaluation

    def _generate_and_save_gp_sample(self, filename, n_rand_obs, seed):
        """Draw a new GP sample on ``self.x_domain``, cache it to ``filename``, and store it."""
        # NOTE(review): this reseeds torch's *global* RNG. Constructing a
        # GPSample without a cached file therefore changes the random state
        # seen by later code, while loading from the cache does not.
        torch.manual_seed(seed)

        # fix some hyperparameters to generate the GP
        init_hyperparameters = {
            "likelihood.noise_covar.noise": 0.01,
            "covar_module.base_kernel.lengthscale": [0.1] * self.xdim,
            "covar_module.outputscale": 1.0,
            "mean_module.constant": 0.0,
        }

        # Condition the GP on a few random observations. Inputs are drawn
        # uniformly from the discrete domain; targets are uniform in [-1, 1).
        X_idx = torch.randint(low=0, high=self.xsize, size=(n_rand_obs,))
        X = self.x_domain[X_idx, :]
        y = torch.rand(n_rand_obs) * 2.0 - 1.0

        gp_model = GP(X, y, initialization=init_hyperparameters, prior=None, ard=True)

        # Draw one joint sample of the posterior over all domain points.
        f_preds = GP.predict_f(gp_model, self.x_domain)
        with torch.no_grad():
            f_samples = f_preds.sample(sample_shape=torch.Size([1]))

        hyperparameters = {
            "likelihood.noise_covar.noise": gp_model.likelihood.noise_covar.noise.item(),
            "covar_module.base_kernel.lengthscale": gp_model.covar_module.base_kernel.lengthscale.detach().numpy(),
            "covar_module.outputscale": gp_model.covar_module.outputscale.item(),
            "mean_module.constant": gp_model.mean_module.constant.item(),
        }

        # reshape(-1) instead of squeeze(): squeeze() would collapse xsize=1
        # to a 0-d array, which breaks len() and index-based lookups.
        evaluation = f_samples.numpy().reshape(-1)

        # Create the cache directory on first use; open(..., "wb") does not.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "wb") as file:
            pickle.dump(
                {"hyperparameters": hyperparameters, "evaluation": evaluation, "gp": gp_model},
                file,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

        self._gp_hyperparameters = hyperparameters
        self.evaluation = torch.tensor(evaluation, dtype=torch.float32)

    def get_noiseless_observation_from_inputs(self, x_idxs):
        """Not supported for GPSample; always raises.

        A GPSample is only defined at the points of its discrete domain.
        Use ``get_noiseless_observation_from_input_idxs`` instead.

        Parameters
        ----------
        x_idxs : object
            ignored

        Raises
        ------
        NotImplementedError
            always (a subclass of ``Exception``, so existing
            ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("This method is not implemented for GPSample!")

    def get_noiseless_observation_from_input_idxs(self, x_idxs):
        """get function evaluation at input idxs

        Parameters
        ----------
        x_idxs : tensor array or list of int64 of shape (n,)
            indices of inputs in self.x_domain to be evaluated

        Returns
        -------
        val : tensor array of float32
            evaluations of the GP sample at inputs specified by x_idxs,
            flattened to shape (n,)

        """
        with torch.no_grad():
            val = self.evaluation[x_idxs].reshape(-1)
        return val
