import os
import json
import pickle
import numpy as np
import torch
import pandas as pd

from gp import GP
from util import normalize_data

from .basefunc import BaseFunc


class Phosphorus(BaseFunc):
    """Benchmark function backed by a GP fitted to the phosphorus dataset.

    The GP (``self.data_gp_model``) is trained on the normalized ``x``/``y``
    coordinates and standardized ``log10P`` values read from a CSV file;
    the noiseless objective at an input is the mean of that GP's prediction.
    """

    # Default file locations. Relative paths are resolved against the
    # current working directory, so callers can override them if needed.
    _DEFAULT_DATA_PATH = "dataset/phosphorus/bbarn.csv"
    _DEFAULT_HYPERPARAMETER_PATH = (
        "func/gp_hyperparameters/gp_hyperparameters_phosphorus.json"
    )

    def __init__(
        self,
        xsize=400,
        noise_std=0.01,
        hyperparameter_filename=_DEFAULT_HYPERPARAMETER_PATH,
        data_filename=_DEFAULT_DATA_PATH,
    ):
        """Load the dataset and build the GP used as the objective.

        Parameters
        ----------
        xsize : int
            the size of the discrete input domain
        noise_std : float
            the standard deviation of the noise of the GP model
            to generate the noisy observation (passed to ``BaseFunc``)
        hyperparameter_filename : str
            JSON file whose contents are passed to ``GP`` as
            ``initialization``; defaults to
            ``func/gp_hyperparameters/gp_hyperparameters_phosphorus.json``
        data_filename : str
            CSV file forwarded to :meth:`preprocess_data`; defaults to
            ``dataset/phosphorus/bbarn.csv``

        """
        xsize = int(xsize)
        noise_std = float(noise_std)

        xdim = 2
        super().__init__(xdim, xsize, noise_std=noise_std)

        self.module_name = "phosphorus"
        self.xsize = xsize
        self.x_domain = BaseFunc.generate_discrete_points(xsize, xdim)

        self.train_x, self.ys = Phosphorus.preprocess_data(data_filename)

        with open(hyperparameter_filename, "r", encoding="utf-8") as f:
            self.hyperparameters = json.load(f)

        self.data_gp_model = GP(
            self.train_x,
            self.ys,
            initialization=self.hyperparameters,
            prior=None,
            ard=True,
        )

    @staticmethod
    def preprocess_data(csv_path=_DEFAULT_DATA_PATH):
        """Load the phosphorus CSV and convert it to GP training tensors.

        Parameters
        ----------
        csv_path : str
            path of the CSV file; it must contain the columns ``x``, ``y``,
            ``P`` and ``log10P``

        Returns
        -------
        train_x : tensor array of float32 of shape (n, 2)
            the (x, y) coordinates after ``normalize_data``
        ys : tensor array of float32 of shape (n,)
            ``log10P`` standardized to zero mean and unit sample
            standard deviation

        Raises
        ------
        ValueError
            if fewer than two rows remain after filtering; the sample
            standard deviation used for standardization would be NaN

        """
        df = pd.read_csv(csv_path)

        # Keep only strictly positive Phosphorus readings: negative (and
        # zero) values are treated as erroneous sensor measurements.
        df = df[df["P"] > 0]
        if len(df) < 2:
            raise ValueError(
                f"need at least 2 rows with P > 0 in {csv_path!r}, "
                f"got {len(df)}"
            )

        # Stack the two coordinate columns into an (n, 2) array.
        train_x = np.stack([df["x"].to_numpy(), df["y"].to_numpy()]).T
        train_x = normalize_data(train_x)
        ys = df["log10P"].to_numpy()

        train_x = torch.from_numpy(train_x)
        ys = torch.from_numpy(ys)
        # torch's std defaults to the unbiased (sample) estimator.
        ys = (ys - ys.mean()) / ys.std()

        return train_x.float(), ys.float()

    def get_noiseless_observation_from_inputs(self, x):
        """Get the noiseless function evaluation at the given inputs.

        Parameters
        ----------
        x : tensor array reshapable to (n, self.xdim)
            inputs to be evaluated

        Returns
        -------
        val : tensor array
            the mean of the data GP's prediction at inputs x
            (``GP.predict_f(...).mean``)

        """
        # A single no_grad scope covers reshaping, prediction and reading
        # the mean; no gradients are needed for objective evaluation.
        with torch.no_grad():
            x = x.reshape(-1, self.xdim)
            f_preds = GP.predict_f(self.data_gp_model, x)
            f_means = f_preds.mean

        return f_means

    def get_noiseless_observation_from_input_idxs(self, x_idxs):
        """Get the noiseless function evaluation at the given input indices.

        Parameters
        ----------
        x_idxs : tensor array or list of int64 of shape (n,)
            indices of inputs in self.x_domain to be evaluated

        Returns
        -------
        val : tensor array
            the mean of the data GP's prediction at the inputs
            selected by x_idxs

        """
        x = self.x_domain[x_idxs, :].reshape(-1, self.xdim)
        return self.get_noiseless_observation_from_inputs(x)
