import numpy as np


class GP:
    """Gaussian-process reward model over a discrete set of actions.

    Instead of a kernel function, the covariance between two actions is read
    from ``similarity_matrix[a, b]``, indexed by action ID. Observed data is a
    list of action IDs (``X_train``) with their rewards (``y_train``), stored
    by :meth:`fit`.
    """

    def __init__(
        self,
        actions,  # list of all action IDs; should be a list of integers from 0 to n-1
        similarity_matrix: object,  # similarity matrix given to the agent; replaces kernel function
        sigma: float = 1.0,  # noise term added as-is to the covariance diagonal (so it acts as a variance)
    ):
        self.similarity_matrix = similarity_matrix
        self.sigma = sigma
        self.actions = actions

    def predict(self, x):
        """Return the predictive mean and variance of the reward for action ``x``.

        Args:
            x: a single action ID; used as a row/column index into
                ``similarity_matrix``.

        Returns:
            ``(mu, s2)``: mean and variance of the Normal distribution over the
            predicted reward. With no observed data (``fit`` never called, or
            called with no actions) this is the fixed prior ``(0, 1)``.

        Raises:
            numpy.linalg.LinAlgError: if the training covariance is singular.
        """
        X_train = getattr(self, "X_train", None)
        if X_train is None or len(X_train) == 0:
            # No observed data. NOTE(review): with no data the posterior formula
            # below would give variance k(x, x) + sigma rather than 1; the fixed
            # (0, 1) prior is kept for backward compatibility.
            return 0, 1

        S = np.asarray(self.similarity_matrix, dtype=float)
        idx = np.asarray(X_train, dtype=int)  # action IDs may arrive as floats
        x = int(x)

        # kernel vector: k[n] = similarity(x, action_n)
        k = S[x, idx]
        # C[n, m] = similarity(action_n, action_m) + sigma * (1 if n == m else 0)
        C = S[np.ix_(idx, idx)] + self.sigma * np.eye(idx.size)

        # Solve linear systems rather than forming C^-1 explicitly: cheaper and
        # numerically more stable, same result.
        alpha = np.linalg.solve(C, np.asarray(self.y_train, dtype=float))
        v = np.linalg.solve(C, k)

        mu = k @ alpha
        c = S[x, x] + self.sigma  # the constant defined as k(x, x) + sigma
        s2 = c - k @ v
        return mu, s2

    def fit(self, X, y):
        """Store the observed data used by :meth:`predict`.

        Args:
            X: the action IDs (NOT features!) of the observed data.
            y: the observed rewards for those actions, in the same order.
        """
        # Stored by reference; predict() reads them on every call.
        self.X_train = X
        self.y_train = y

    def select_action(self, allowed_actions):
        """Choose an action by sampling, restricted to ``allowed_actions``.

        One reward is sampled from each action's predictive distribution
        (see :meth:`predict`), and the allowed action with the highest sample
        is returned. Ties go to the smallest allowed action ID.

        Args:
            allowed_actions: iterable of action IDs that may be chosen. IDs are
                used as positions into the per-action samples, so actions are
                assumed to be 0..n-1 (as documented for ``actions``).

        Returns:
            The chosen action ID (a NumPy integer).

        Raises:
            ValueError: if ``allowed_actions`` is empty.
        """
        # 1. Sample a reward from every action's predictive distribution.
        predictions = [self.predict(action) for action in self.actions]
        means = [mean for mean, _ in predictions]
        # abs() guards against tiny negative variances from floating-point round-off.
        stds = [np.sqrt(abs(var)) for _, var in predictions]
        samples = np.random.normal(means, stds)

        # 2. Take the argmax over the allowed samples only, so a disallowed action
        #    with an equal sample can never be returned. np.unique sorts the IDs,
        #    and argmax picks the first maximum, so ties go to the smallest ID.
        allowed = np.unique(np.asarray(list(allowed_actions), dtype=int))
        if allowed.size == 0:
            raise ValueError("allowed_actions must contain at least one action")
        return allowed[np.argmax(samples[allowed])]


