from online_sc import MultiplicativeWeightsAlgo, CannotCoverException
import numpy as np
from onlinealgo import OnlineAlgo


class StandardOnlineCombiner(OnlineAlgo):
    """Run two ``MultiplicativeWeightsAlgo`` instances and switch between them.

    One instance is built from ``sc_input`` alone (``gen_algo``); the other is
    additionally given ``prediction`` (``pred_algo``).  Each arriving element
    is served by the currently active instance.  When that instance cannot
    cover the element, or serving it would push the cost of the combined
    solution above ``cur_cost_cap``, the cap is doubled and the other
    instance becomes active.
    """

    def __init__(self, sc_input, prediction):
        super().__init__(sc_input)
        self.gen_algo = MultiplicativeWeightsAlgo(sc_input)
        self.pred_algo = MultiplicativeWeightsAlgo(sc_input, prediction=prediction)
        # Not read within this class; kept because outside code may inspect it.
        self.total_pen = 0

        # Index 0 (gen_algo, built without the prediction) is active first.
        self.algos = [self.gen_algo, self.pred_algo]
        self.cur_algo_index = 0

        # The cap only grows by doubling, so it must start strictly positive:
        # seeding it with a minimum price of 0 (or a negative one) means it
        # never grows and request_element can switch algorithms forever.
        # Use the cheapest positive set price.  If no price is positive, no
        # merged solution can cost more than 0 (assuming non-negative
        # solution entries), so a cap of 0 never forces a switch.
        prices = np.asarray(self.sc_input.set_prices)
        positive_prices = prices[prices > 0]
        self.cur_cost_cap = positive_prices.min() if positive_prices.size else 0

    def request_element(self, elem):
        """Serve the arrival of ``elem`` with one of the combined algorithms.

        The active algorithm first simulates covering ``elem``.  If the
        simulation raises ``CannotCoverException``, or the element-wise
        maximum of ``self.solution`` and the simulated solution costs more
        than ``cur_cost_cap`` (cost = dot product with ``set_prices``), the
        cap is doubled, the next algorithm becomes active, and the attempt is
        repeated.  Otherwise the active algorithm serves ``elem`` for real and
        its solution is merged into ``self.solution`` by element-wise maximum.

        Raises:
            CannotCoverException: if every algorithm failed to cover ``elem``.
        """
        # Indices of algorithms whose simulation failed for this element.
        could_not_cover = set()
        while True:
            if len(could_not_cover) == len(self.algos):
                raise CannotCoverException('no set containing element')

            cur_algo = self.algos[self.cur_algo_index]
            try:
                solution = cur_algo.request_element(elem, simulate=True)
            except CannotCoverException:
                could_not_cover.add(self.cur_algo_index)
                must_switch = True
            else:
                # self.solution is presumably initialised by OnlineAlgo.
                combined_cost = np.maximum(self.solution, solution).dot(
                    self.sc_input.set_prices)
                # Keep the strict '>' so a NaN cost is accepted, as before.
                must_switch = combined_cost > self.cur_cost_cap

            if must_switch:
                # A failed algorithm is simulated again on its next turn; the
                # check at the top ends the loop once all of them have failed.
                self.cur_cost_cap *= 2
                self.cur_algo_index = (self.cur_algo_index + 1) % len(self.algos)
                continue

            cur_algo.request_element(elem, simulate=False)
            self.solution = np.maximum(self.solution, cur_algo.solution)
            return
