import numpy as np

from .base import AbstractAgent


class AgentAlg1(AbstractAgent):
    """Agent that plays pairs of arms ``(a, b)`` around one "current" arm.

    The first arm of every pair is ``self.current_arm``; the second is the
    arm with the largest optimistic index computed from the pair statistics
    of the current arm.  The current arm advances cyclically once its score
    falls to the threshold ``self.B``.

    State created by :meth:`reset_learning`:

    * ``wij`` -- ``K x K`` array; entry ``[i, j]`` accumulates the
      observations received when the pair ``(i, j)`` was played.
    * ``nij`` -- ``K x K`` array; entry ``[i, j]`` counts plays of ``(i, j)``.
    * ``si``  -- length-``K`` array of per-arm scores.
    * ``B``   -- score threshold, starts at ``-1`` and is doubled after each
      full cycle through the arms.
    * ``current_arm`` -- index used as the first arm of every pair.

    ``self.env`` (with attribute ``K``) and ``self.t`` are read but never set
    in this class -- presumably they are provided by ``AbstractAgent``;
    verify against the base class.
    """

    def reset_learning(self):
        """Reset all statistics, the threshold and the current arm."""
        num_arms = self.env.K
        pair_shape = (num_arms, num_arms)
        self.wij = np.zeros(pair_shape)    # accumulated observations for (i, j)
        self.nij = np.zeros(pair_shape)    # number of times (i, j) was played
        self.si = np.zeros(num_arms)       # running score of each arm
        self.B = -1
        self.current_arm = 0

    def sample_action(self):
        """Return the next pair ``(a, b)`` to play.

        ``a`` is the current arm.  ``b`` maximises, over all arms other than
        ``a``, the index ``1 - wij[a, j] / nij[a, j]`` plus a confidence
        radius that shrinks with ``nij[a, j]``.  Arms never played against
        ``a`` keep an index of ``+inf`` and are therefore tried first.  Ties
        are broken uniformly at random via ``np.random.choice``.
        """
        first = self.current_arm
        outcomes = self.wij[first]
        plays = self.nij[first]
        played = plays != 0

        # NOTE(review): ``horizon_term`` already contains a logarithm and is
        # passed through ``np.log`` once more in the radius below -- confirm
        # this matches the intended confidence bound.
        step = self.t + 2
        horizon_term = 2 * np.log(1 + step * np.square(np.log(step)))

        # Unplayed opponents stay at +inf so that they are explored first.
        index = np.full(self.env.K, np.inf)
        complement_rate = 1 - outcomes[played] / plays[played]
        radius = np.sqrt(2 * np.log(horizon_term) / plays[played])
        index[played] = complement_rate + radius

        # An arm is never paired with itself (unless nothing else exists).
        index[first] = -np.inf

        best = np.argwhere(index == index.max()).flatten()
        second = np.random.choice(best)
        return (first, second)

    def learn(self, action, observation):
        """Update the statistics with the outcome of one played pair.

        ``action`` is indexed as ``(action[0], action[1])``; ``observation``
        is added to ``wij`` for that pair -- presumably 1 when ``action[0]``
        wins and 0 otherwise; verify against the environment.

        The score of the current arm changes by ``observation - 1/2``.  When
        that score is at or below ``self.B`` the agent moves on to the next
        arm, and every time the cycle wraps around to arm 0 the threshold
        ``self.B`` is doubled.
        """
        row, col = action[0], action[1]
        self.wij[row, col] += observation
        self.nij[row, col] += 1

        arm = self.current_arm
        self.si[arm] += observation - 1 / 2

        if self.si[arm] <= self.B:
            following = (arm + 1) % self.env.K
            self.current_arm = following
            if following == 0:
                # A full pass over all arms is complete: widen the threshold.
                self.B *= 2
