import numpy as np

from comab.algo.baselines import CoMABAlgo


class EXP3(CoMABAlgo):
    """EXP3 (exponential-weight) policy over ``N + 1`` actions.

    One non-negative weight is kept per action. Actions are sampled from a
    mixture of the normalised weights and the uniform distribution, with
    ``gamma`` controlling the uniform (exploration) share.

    NOTE(review): ``self.n`` is not created in this class; it is presumably
    provided by ``CoMABAlgo`` -- confirm against the base class.
    """

    def __init__(self, K, N, p, gamma=0.07, **kwargs):
        """Initialise uniform weights over ``N + 1`` actions.

        Args:
            K, N, p: Forwarded unchanged to ``CoMABAlgo.__init__``.
            gamma: Mixing coefficient of the uniform distribution; it also
                scales the exponent of the weight update.
            **kwargs: Accepted and ignored.
        """
        super().__init__(K, N, p)
        self.num_actions = N + 1
        self.gamma = gamma
        self.weights = np.ones(self.num_actions)

    def get_probabilities(self):
        """Return the sampling distribution over actions.

        Each entry is ``(1 - gamma) * w_i / sum(w) + gamma / num_actions``.
        """
        total_weight = np.sum(self.weights)
        probabilities = (1 - self.gamma) * (self.weights / total_weight) + self.gamma / self.num_actions
        return probabilities

    def get_action(self):
        """Sample an action index according to ``get_probabilities()``."""
        probabilities = self.get_probabilities()
        # An integer first argument is sampled as if it were np.arange(n).
        return np.random.choice(self.num_actions, p=probabilities)

    def _update(self, action, reward):
        """Apply the exponential-weight update for one played action.

        The reward is importance-weighted by the probability with which
        ``action`` is currently sampled, and only that action's weight is
        multiplied.
        """
        estimated_reward = reward / self.get_probabilities()[action]
        self.weights[action] *= np.exp(self.gamma * estimated_reward / self.num_actions)

        # Only the ratios between weights enter get_probabilities(), so
        # rescaling by the largest weight leaves the policy unchanged while
        # preventing the weights from overflowing to inf (or all underflowing
        # to 0) over long runs, which would turn the probabilities into NaN.
        peak = np.max(self.weights)
        if np.isfinite(peak) and peak > 0:
            self.weights /= peak

    def update(self, arms_with_observation, observed_gains, observed_costs, t):
        """Feed back the outcome at index 0 and draw the next action.

        The update is applied when ``arms_with_observation[0]`` is truthy or
        when the current ``self.n[0]`` equals 0. The reward passed on is
        ``observed_gains[0] - observed_costs[0]``, credited to the action
        stored in ``self.n[0]``; a freshly sampled action is then written
        back to ``self.n[0]``. ``t`` is not used.
        """
        if arms_with_observation[0] or self.n[0] == 0:
            self._update(self.n[0], observed_gains[0] - observed_costs[0])
            self.n[0] = self.get_action()
