import numpy as np

from comab.algo.baselines import CoMABAlgo


class UCB1(CoMABAlgo):
    """UCB1 index policy over the ``N + 1`` candidate values of one slot.

    Each index ``0..N`` is treated as an arm. After every observation the
    arm with the largest upper confidence bound
    ``mean_reward + sqrt(2 * log(t) / plays)`` becomes the leader and is
    written to ``self.n[0]``.

    ``self._N`` and ``self.n`` are not defined here; they are presumably set
    up by ``CoMABAlgo.__init__`` -- confirm against the base class.
    """

    def __init__(self, K, N, p, c, **kwargs):
        """Initialise per-arm statistics and pick the middle arm first.

        Args:
            K: Forwarded to the base class; must equal 1.
            N: Forwarded to the base class; ``self._N + 1`` arms are tracked.
            p: Forwarded to the base class.
            c: Stored as ``self.c``; not otherwise used by this class.
            **kwargs: Accepted and ignored.
        """
        super().__init__(K, N, p)
        assert K == 1
        # Cumulative net reward and number of plays for every arm.
        self.sum_reward = np.zeros(self._N + 1)
        self.t_k = np.zeros(self._N + 1, dtype=int)
        self.c = c
        # Start from the middle arm.
        self.leader = int(self._N / 2)
        self.n[0] = self.leader
        self.ucb = np.zeros(self._N + 1)

    @property
    def r_n(self):
        """Per-arm empirical mean reward, clipped to ``[0, 1]``.

        Arms that have never been played report a mean of 0. The division is
        only evaluated with a non-zero denominator, so no divide-by-zero or
        invalid-value warnings are raised for unplayed arms.
        """
        played = self.t_k > 0
        # np.maximum keeps the denominator non-zero; the np.where mask then
        # discards the placeholder quotient for unplayed arms.
        means = np.where(played, self.sum_reward / np.maximum(self.t_k, 1), 0.0)
        # nan_to_num still guards against non-finite accumulated rewards.
        return np.clip(np.nan_to_num(means), 0, 1)

    def update(self, arms_with_observation, observed_gains, observed_costs, t):
        """Record the latest observation and re-select the leader.

        Args:
            arms_with_observation: Unused by this policy.
            observed_gains: Indexable; element 0 is the gain of the arm
                currently stored in ``self.n[0]``.
            observed_costs: Indexable; element 0 is the cost of that arm.
            t: Round counter used in the exploration bonus. Values below 1
                are treated as 1 so the logarithm is never negative.
        """
        arm = self.n[0]
        self.sum_reward[arm] += observed_gains[0] - observed_costs[0]
        self.t_k[arm] += 1

        # Unplayed arms get an infinite index so they are explored first;
        # this is stated explicitly rather than left to x/0 or NaN ordering
        # inside argmax.
        played = self.t_k > 0
        log_t = np.log(max(t, 1))
        bonus = np.full(self.t_k.shape, np.inf)
        bonus[played] = np.sqrt(2 * log_t / self.t_k[played])

        self.ucb = self.r_n + bonus
        # argmax returns the first maximum, i.e. the lowest-index unplayed
        # arm while any arm remains unplayed.
        self.leader = np.argmax(self.ucb)
        self.n[0] = self.leader
