from abc import ABC, abstractmethod

import numpy as np


class RewardEstimator(ABC):
    """Common base for reward estimators over a set of auction arms.

    This base class only records the problem dimensions; concrete
    estimators must override :meth:`update_estimator` and :meth:`r_hat`.

    Args:
        K: Number of arms (auctions).
        N: Number of players.
        p: Number of competitors per arm. ``max`` is applied to it, so it
            must be a non-empty iterable of mutually comparable values.
        **kwargs: Accepted and discarded here. Presumably this lets
            subclasses share a constructor signature with extra options
            (unverified).
    """

    def __init__(self, K, N, p, **kwargs):
        self._K, self._N, self._p = K, N, p
        # Cache the largest per-arm competitor count.
        self._P = max(self._p)

    @abstractmethod
    def update_estimator(self, n, arms_with_observation, observed_gains, observed_costs, t):
        """Incorporate a batch of observations into the estimator.

        The exact meaning of each argument is defined by the concrete
        subclass. Judging by the names, ``n`` presumably identifies a player
        and ``t`` a time step; confirm against the implementations.
        """

    @abstractmethod
    def r_hat(self, **kwargs):
        """Return the estimator's output, as defined by the concrete subclass.

        The name suggests an estimated reward (r-hat); the accepted keyword
        options are subclass-specific.
        """

def F_hat_combi(_F_hat, t):
    """Combine per-index estimates into windowed, weighted estimates of ``F^j``.

    With ``M = t.shape[1]`` and columns indexed by ``i``/``j`` in ``0..M-1``,
    each output entry is the ``t``-weighted average

        res[k, j] = sum_{i in W(j)} t[k, i] * _F_hat[k, i] ** (j / i)
                    / sum_{i in W(j)} t[k, i]

    over the window ``W(j) = [ceil(j / 4), min(2 * j, M - 1)]``.
    Terms ``t * _F_hat ** (j / i)`` that are NaN or +/-inf are counted as 0
    in the numerator but their weight ``t`` is still counted in the
    denominator.

    Args:
        _F_hat: Array broadcastable against ``t``; presumably shape
            (K, N+P+1), one estimate per arm ``k`` and index ``i`` (confirm
            with caller).
        t: 2-D array of non-negative weights of shape (K, M).

    Returns:
        Float array of shape (K, M). Where a window's total weight is zero
        the result is NaN (0/0). In particular, column ``j = 0`` has an empty
        window and is NaN for finite ``t``.
    """
    n_cols = np.shape(t)[1]
    my_range = np.arange(n_cols, dtype=int)

    # Division by zero (j / 0 at i == 0) and 0/0 (empty or zero-weight
    # windows) are expected here and handled explicitly (nan_to_num / NaN
    # result), so silence those floating-point errors locally instead of
    # warning, or raising under a global np.seterr(all="raise").
    with np.errstate(divide="ignore", invalid="ignore"):
        # exponent[i, j] = j / i; row i == 0 is inf (j > 0) or nan (j == 0),
        # but i == 0 never falls inside any window below.
        exponent = my_range[np.newaxis, :] / my_range[:, np.newaxis]

        # F_i2j[k, i, j] = _F_hat[k, i] ** (j / i)
        F_i2j = np.power(_F_hat[..., np.newaxis], exponent[np.newaxis, ...])

        # t[k, i] * F^{j/i}; cumulative sum over i so window sums are O(1).
        weighted = F_i2j * t[..., np.newaxis]
        cum_weighted = np.nan_to_num(
            weighted, nan=0.0, posinf=0.0, neginf=0.0
        ).cumsum(axis=1)

        # Window (lb_idx, ub_idx] on the cumulative sums, i.e. inclusive
        # i in [ceil(j / 4), 2 * j]; lb is shifted by one because
        # cum[ub] - cum[lb] excludes index lb. Clipped to valid indices.
        lb_idx = np.maximum(np.ceil(my_range / 4).astype(int) - 1, 0)
        ub_idx = np.minimum(2 * my_range, n_cols - 1).astype(int)

        # Advanced indexing pairs ub_idx[j] / lb_idx[j] with column j,
        # giving shape (K, M).
        windowed_sum = (
            cum_weighted[:, ub_idx, my_range] - cum_weighted[:, lb_idx, my_range]
        )
        cum_t = t.cumsum(axis=-1)
        windowed_weights = cum_t[:, ub_idx] - cum_t[:, lb_idx]

        res = windowed_sum / windowed_weights
    return res