import numpy as np

from comab.reward_estimator.reward_estimator import RewardEstimator, F_hat_combi


def F_hat(z, X, t):
    """Return the empirical CDF value at threshold ``z``.

    Counts, along the last axis of ``X``, the entries that are strictly
    positive and not larger than ``z``, then divides that count by ``t``.

    Args:
        z: Threshold at which the CDF is evaluated (compared against ``X``).
        X: Array of stored samples; the last axis is the one summed over.
            Entries equal to zero (or negative) are never counted.
        t: Divisor for the counts; must broadcast against ``X`` with its
            last axis removed.  Presumably the number of samples per cell
            -- confirm against the caller.

    Returns:
        The counts divided by ``t``.  Cells where ``t`` is 0 produce
        non-finite values (NumPy division by zero).
    """
    is_recorded = X > 0  # zero entries are excluded from the count
    is_below = X <= z
    hits = np.sum(np.logical_and(is_recorded, is_below), axis=-1)
    return hits / t


def r_hat(X, t, D=100):
    """Estimate rewards by averaging CDF differences over a grid on (0, 1].

    For every grid point ``u / D`` with ``u = 1, ..., D`` the empirical CDF
    is evaluated with ``F_hat``, passed through ``F_hat_combi``, and the
    positive part of the difference between each column and the next one is
    accumulated.  The accumulated sum is finally divided by ``D``.

    Args:
        X: Stored samples, forwarded unchanged to ``F_hat``.
        t: Two-dimensional array of counts; its shape ``(K, M)`` determines
            the ``(K, M - 1)`` shape of the result.
        D: Number of grid points used for the average.

    Returns:
        Array of shape ``(K, M - 1)`` holding the averaged, clipped
        differences of adjacent columns.
    """
    K, NpPp1 = np.shape(t)
    reward = np.zeros((K, NpPp1 - 1))
    for u in range(1, D + 1):
        # Evaluate the combined CDF once per grid point and reuse it for
        # both shifted slices instead of calling the helper twice.
        combined = F_hat_combi(F_hat(u / D, X, t), t)
        reward += np.maximum(combined[:, :-1] - combined[:, 1:], 0)
    return reward / D


class WithFullEmpiricalCDF(RewardEstimator):
    """Reward estimator that stores every observed gain.

    Gains are kept in a dense array with one slot per round, alongside a
    per-cell counter; both are handed to the module-level ``r_hat`` to
    produce the estimate.
    """

    def __init__(self, K, N, p, T, **kwargs):
        """Allocate the sample buffer and the sample counters.

        Args:
            K, N, p: Forwarded to ``RewardEstimator.__init__``.
            T: Length of the last axis of the sample buffer (one slot per
                value of ``t`` passed to ``update_estimator``).
            **kwargs: Forwarded to ``RewardEstimator.__init__``.
        """
        super().__init__(K, N, p, **kwargs)
        cell_shape = (self._K, self._N + self._P + 1)
        # samples drawn from distribution; zero marks a slot without a sample
        self._X = np.zeros(cell_shape + (T,))
        # number of samples (ml in overleaf)
        self._t = np.zeros(cell_shape, dtype=int)

    def update_estimator(self, n, arms_with_observation, observed_gains, observed_costs, t):
        """Record the gains of round ``t`` and bump the sample counters.

        ``observed_costs`` is accepted but not used by this estimator.
        """
        arm_idx = np.arange(self._K)
        # NOTE(review): second-axis index is ``n + self._p``; presumably one
        # entry per arm -- confirm against RewardEstimator.
        cell_idx = n + self._p
        self._X[arm_idx, cell_idx, t] = observed_gains
        self._t[arm_idx, cell_idx] += arms_with_observation

    def r_hat(self, D=100, **kwargs):
        """Return the reward estimate computed on a ``D``-point grid.

        Extra keyword arguments are accepted and ignored.
        """
        samples, counts = self._X, self._t
        return r_hat(samples, counts, D)
