import numpy as np

from .base import AbstractAgent


class AgentWrExp3IX(AbstractAgent):
    """Dueling agent that keeps one reference arm and samples its opponent.

    The first arm of every duel is ``current_arm``; the second arm is drawn
    from the distribution ``p`` over all other arms. ``p`` is an exponential
    weighting of cumulative importance-weighted loss estimates, where the
    importance weight uses ``p + gamma`` in the denominator (the
    implicit-exploration form of the estimator).

    Learning proceeds in phases. A phase ends when the sum of the loss
    estimates, divided by the square root of the number of rounds in the
    phase, drops to ``-B`` or below. The reference arm then advances
    cyclically, and ``B`` doubles each time the cycle wraps back to arm 0.

    The code reads ``self.env.K`` as the number of arms.
    NOTE(review): assumes ``K >= 2``; with a single arm there is no valid
    opponent and ``p`` cannot be normalised.
    """

    def _start_phase(self, arm):
        """Make ``arm`` the reference arm and clear the per-phase statistics."""
        n_arms = self.env.K
        self.current_arm = arm
        self.ni = 0  # rounds played in the current phase
        self.loss_estimates = np.zeros(n_arms)
        # Uniform over every arm except the reference arm, which is never
        # drawn as its own opponent.
        self.p = np.ones(n_arms) / (n_arms - 1)
        self.p[arm] = 0

    def reset_learning(self):
        """Restore the initial state: threshold ``B = 1``, reference arm 0."""
        self.B = 1
        self._start_phase(0)

    def sample_action(self):
        """Return the duel ``(current_arm, opponent)`` with opponent ~ ``p``."""
        a = self.current_arm
        # An integer population is equivalent to sampling from arange(K).
        b = np.random.choice(self.env.K, p=self.p)
        return (a, b)

    def learn(self, action, observation):
        """Update the opponent distribution from one duel outcome.

        Args:
            action: the ``(reference_arm, opponent)`` pair that was played;
                only the opponent index ``action[1]`` is read.
            observation: outcome of the duel. ``observation - 1/2`` is used
                as the loss charged to the opponent.
                NOTE(review): presumably a value in [0, 1] - confirm against
                the environment.
        """
        n_arms = self.env.K
        opponent = action[1]

        self.ni += 1
        # Learning rate decays with the number of rounds in the current
        # phase (no fixed horizon is needed).
        eta = np.sqrt(np.log(n_arms) / (n_arms * self.ni))
        gamma = eta / 2

        # Importance-weighted loss with gamma added to the denominator, which
        # bounds the estimate even when p[opponent] is tiny.
        loss_estimate = (observation - (1 / 2)) / (self.p[opponent] + gamma)
        self.loss_estimates[opponent] += loss_estimate

        # Exponential weights, computed in log space. Masking the reference
        # arm with -inf gives it weight exactly 0, and subtracting the largest
        # exponent before np.exp leaves the normalised distribution unchanged
        # while preventing overflow to inf / underflow of every weight to 0
        # (either of which would turn p into NaN).
        log_weights = -eta * self.loss_estimates
        log_weights[self.current_arm] = -np.inf
        log_weights -= np.max(log_weights)
        weights = np.exp(log_weights)
        self.p = weights / np.sum(weights)

        # End the phase once the normalised cumulative loss reaches -B.
        if np.sum(self.loss_estimates) / np.sqrt(self.ni) <= -self.B:
            next_arm = (self.current_arm + 1) % n_arms
            self._start_phase(next_arm)
            if next_arm == 0:
                # A full cycle over all reference arms finished: double the
                # threshold for the next cycle.
                self.B = self.B * 2
