import numpy as np
from scipy.integrate import quad

def r_per_unit(F, M, lb, ub):
    """Tabulate the integrals of ``cdf**m - cdf**(m + 1)`` for every arm.

    Parameters
    ----------
    F : sequence
        One object per arm exposing a ``cdf(z)`` method.
    M : int
        Number of columns to compute (exponents ``m = 0 .. M - 1``).
    lb, ub : sequence
        Lower / upper integration bound for each arm.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(K, M)`` whose entry ``[k, m]`` is the integral over
        ``[lb[k], ub[k]]`` of ``F[k].cdf(z)**m - F[k].cdf(z)**(m + 1)``.
    """
    n_arms = np.shape(F)[0]
    table = np.zeros((n_arms, M))

    for arm in range(n_arms):
        dist = F[arm]
        lower, upper = lb[arm], ub[arm]
        for m in range(M):

            def integrand(z):
                # Evaluate the cdf once and reuse it for both powers.
                c = dist.cdf(z)
                return np.power(c, m) - np.power(c, m + 1)

            # quad returns (value, abs_error_estimate); only the value is kept.
            value, _abs_err = quad(integrand, lower, upper)
            table[arm, m] = value

    return table


def r_n(_r_per_unit, n, p):
    """Return the total reward of allocation ``n``.

    For each arm ``k`` the per-unit reward is read from column
    ``n[k] + p[k] - 1`` of ``_r_per_unit`` (NaN entries are replaced via
    ``np.nan_to_num``), and the result is the dot product of ``n`` with those
    per-unit rewards.
    """
    arms = np.arange(len(n))
    columns = n + p - 1
    per_unit = np.nan_to_num(_r_per_unit[arms, columns])
    return n @ per_unit

def e_i(i, d):
    """Return the ``i``-th canonical basis vector of length ``d`` as an int array.

    If ``i`` is outside ``0 .. d - 1`` the result is all zeros.
    """
    positions = np.arange(d)
    return (positions == i).astype(int)


def greedy_allocation(R, p, N):
    """Greedily distribute up to ``N`` units over the arms, one unit at a time.

    Parameters
    ----------
    R : numpy.ndarray
        Per-unit reward table with one row per arm (as consumed by ``r_n``).
    p : array-like
        Per-arm offsets forwarded to ``r_n``.
    N : int
        Maximum total number of units to allocate.

    Returns
    -------
    numpy.ndarray
        Integer vector with the number of units given to each arm.
    """
    K, _ = R.shape
    n = np.zeros(K, dtype=int)

    while np.sum(n) < N:  # there are still some available units
        # Reward obtained by adding a single unit to each arm in turn.
        candidates = [r_n(R, n + e_i(arm, K), p) for arm in range(K)]
        best_arm = np.argmax(candidates)

        # Stop as soon as even the best extra unit degrades the reward.
        # Written as a strict "<" so that a NaN comparison does not stop
        # the loop.
        if np.max(candidates) < r_n(R, n, p):
            break

        n += e_i(best_arm, K)

    return n


def second_price_auction(bids, has_ties=True):
    """Run one second-price auction per row of ``bids``.

    Parameters
    ----------
    bids : array-like, shape (K, M) with M >= 2
        ``bids[k, j]`` is the bid of participant ``j`` in auction ``k``.
    has_ties : bool, default True
        If True, the winner of each auction is drawn uniformly at random
        (via ``np.random.choice``) among the participants holding the
        highest bid. If False, ``argmax`` is used, which picks the first
        maximal bid.

    Returns
    -------
    winners : numpy.ndarray, shape (K,)
        Column index of the winning bid in each auction.
    won_values : numpy.ndarray, shape (K,)
        Highest bid of each auction.
    payments : numpy.ndarray, shape (K,)
        Second-highest bid of each auction (equal to the highest on a tie).
    """
    bids = np.asarray(bids)
    won_values = bids.max(axis=-1)

    if has_ties:
        # keepdims=True keeps the row maxima as a (K, 1) column so that every
        # bid is compared with the maximum of ITS OWN row. Without it the
        # (K,) vector broadcasts along the columns, which either raises
        # (K != M) or compares against the wrong row's maximum (K == M).
        is_top = bids >= bids.max(axis=-1, keepdims=True)
        winners = np.array(
            [np.random.choice(np.nonzero(row)[0]) for row in is_top],
            dtype=int,
        )
    else:
        winners = bids.argmax(axis=-1)

    # Second-largest value of each row.
    payments = np.sort(bids, axis=-1)[:, -2]
    return winners, won_values, payments


class CoMABEnv:
    """Environment made of K parallel second-price auctions (the arms).

    On arm ``k`` the players' units bid against ``p[k]`` competitors; all bids
    on that arm are drawn from ``F[k]``.

    Attributes
    ----------
    n_star : numpy.ndarray
        Allocation returned by ``greedy_allocation`` at construction time.
    r_star : float
        Reward of ``n_star`` as computed by ``r_n``.
    """

    def __init__(self, p, F, N, has_ties=True):
        """Build the environment and precompute the reward table.

        Parameters
        ----------
        p : sequence of int
            Number of competitors on each arm.
        F : sequence
            One distribution object per arm; must provide ``support()``,
            ``cdf()`` and ``rvs()``.
        N : int
            Number of players.
        has_ties : bool, default True
            Forwarded to ``second_price_auction`` to select the tie-breaking
            rule.
        """
        self._F = F  # per-arm bid distributions (support / cdf / rvs are used)
        self._p = p  # number of competitors per arm
        self._P = max(self._p)  # max number of competitors
        self._N = N  # number of players
        self._K = len(self._p)  # number of arms (auctions)
        self._has_ties = has_ties

        # support() yields one (lower, upper) pair per arm; transpose them
        # into two per-arm lists of integration bounds.
        lb, ub = np.array([f.support() for f in F]).T.tolist()

        # N + P columns are enough to cover every index n[k] + p[k] - 1.
        self._r_per_unit = r_per_unit(self._F, self._N + self._P, lb, ub)
        self.n_star = greedy_allocation(self._r_per_unit, self._p, self._N)
        self.r_star = r_n(self._r_per_unit, self.n_star, self._p)

    def r(self, n):
        """Return the reward of allocation ``n`` (see ``r_n``)."""
        return r_n(self._r_per_unit, n, self._p)

    def step(self, n):
        """Simulate one round of auctions for allocation ``n``.

        Parameters
        ----------
        n : sequence of int
            Number of player units placed on each arm.

        Returns
        -------
        is_arm_observed : numpy.ndarray of bool, shape (K,)
            True where the winning bid sits at an index ``>= p[k]`` and at
            least one player unit was placed on the arm.
        observed_gains : numpy.ndarray, shape (K,)
            Winning bid on observed arms, ``0.0`` elsewhere.
        observed_costs : numpy.ndarray, shape (K,)
            Second-highest bid on observed arms, ``0.0`` elsewhere.
        """
        # Unused slots are padded with -inf so that they can never win.
        bids = -np.inf * np.ones((self._K, self._N + self._P))
        for k, f in enumerate(self._F):
            n_bidders = n[k] + self._p[k]
            # Slots [0, p[k]) hold the competitors, slots [p[k], n_bidders)
            # hold the players' units.
            bids[k, :n_bidders] = f.rvs(size=n_bidders)  # create bids from a distribution f

        winners, won_values, payments = second_price_auction(bids, self._has_ties)

        # The n > 0 guard keeps an arm without any bidder (row of -inf only)
        # from being reported as won by a player.
        is_arm_observed = (winners >= self._p) & (np.asarray(n) > 0)

        # Mask with np.where rather than multiplying by the boolean mask:
        # payments (and won_values) can be -inf because of the padding, and
        # -inf * False evaluates to NaN.
        # NOTE(review): a player who is the sole bidder on an arm still gets
        # a cost of -inf (the padding value) -- confirm the intended payment.
        observed_gains = np.where(is_arm_observed, won_values, 0.0)
        observed_costs = np.where(is_arm_observed, payments, 0.0)
        return is_arm_observed, observed_gains, observed_costs
