from online_sc import MultiplicativeWeightsAlgo, CannotCoverException
import numpy as np
from onlinealgo import OnlineAlgo


class SmoothOnlineCombiner(OnlineAlgo):
    """Combine a prediction-free and a prediction-based multiplicative-weights
    algorithm for online set cover.

    Both sub-algorithms are fed every requested element with a shared penalty
    ``pen``; the combined solution is their element-wise maximum.
    """

    # Factor by which the penalty is multiplied on each step of the search
    # for a penalty under which the combined solution covers the element.
    PEN_GROWTH = 1.1

    def __init__(self, sc_input, prediction):
        """
        :param sc_input: set-cover instance; this class reads its
            ``connections`` (row-indexed by element, one column per set;
            presumably a scipy sparse matrix given the ``.toarray()`` call),
            ``set_prices`` and ``num_sets``.
        :param prediction: prediction passed to the prediction-based
            sub-algorithm.
        """
        super().__init__(sc_input)
        self.gen_algo = MultiplicativeWeightsAlgo(sc_input)
        self.pred_algo = MultiplicativeWeightsAlgo(sc_input, prediction=prediction)
        # Sum of the penalties committed by request_element so far.
        self.total_pen = 0

    def request_element(self, elem):
        """Serve a request for element ``elem``.

        If the current combined solution already gives ``elem`` a coverage
        of at least 1, nothing happens. Otherwise the penalty is grown
        geometrically (by ``PEN_GROWTH``) until the element-wise maximum of
        the two simulated sub-solutions covers ``elem``. That penalty is then
        committed to both sub-algorithms and added to ``total_pen``, and
        ``self.solution`` is recomputed.

        :raises CannotCoverException: if neither sub-algorithm can cover
            ``elem``, or if the penalty grows to a non-finite value without
            the element becoming covered.
        """
        # The element's incidence row is loop-invariant: compute it once.
        elem_row = self.sc_input.connections[[elem], :].toarray()[0, :]
        init_sat = self.solution.dot(elem_row)
        if init_sat >= 1:
            return
        remaining_sat = 1 - init_sat

        # First: find a low-enough penalty for which the combination covers
        # the element.
        pen = self.sc_input.set_prices.min() * remaining_sat
        while True:
            gen_sol, gen_ok = self._simulate(self.gen_algo, elem, pen)
            pred_sol, pred_ok = self._simulate(self.pred_algo, elem, pen)

            if not gen_ok and not pred_ok:
                raise CannotCoverException('no set containing element')

            if np.maximum(gen_sol, pred_sol).dot(elem_row) >= 1:
                break

            # A non-positive penalty (e.g. a free set) would never grow under
            # multiplication, so restart from a strictly positive value.
            if pen > 0:
                pen = pen * self.PEN_GROWTH
            else:
                pen = self._restart_pen(remaining_sat)
            if not np.isfinite(pen):
                raise CannotCoverException(
                    'penalty diverged without covering element')

        # Then: commit the chosen penalty to both sub-algorithms.
        for algo in (self.gen_algo, self.pred_algo):
            try:
                algo.request_element(elem, pen, simulate=False)
            except CannotCoverException:
                # A sub-algorithm that cannot cover ``elem`` may raise here;
                # the search above ensured the combination still covers it.
                pass

        self.total_pen += pen
        self.solution = np.maximum(self.gen_algo.solution, self.pred_algo.solution)

    def _simulate(self, algo, elem, pen):
        """Dry-run ``algo`` on ``elem`` at ``pen``; return ``(solution, ok)``.

        If ``algo`` raises CannotCoverException, an all-zero solution is
        returned with ``ok=False`` so that the other sub-algorithm must pick
        up the element.
        """
        try:
            return algo.request_element(elem, pen, simulate=True), True
        except CannotCoverException:
            return np.zeros(self.sc_input.num_sets), False

    def _restart_pen(self, remaining_sat):
        """Return a strictly positive penalty to continue the search from.

        Uses the smallest positive set price (or 1.0 if there is none),
        scaled by ``remaining_sat`` like the initial penalty, and never below
        the smallest positive normal float so it cannot underflow to 0.
        """
        prices = np.asarray(self.sc_input.set_prices, dtype=float).ravel()
        positive = prices[prices > 0]
        base = positive.min() if positive.size else 1.0
        return max(base * remaining_sat, np.finfo(float).tiny)
