import numpy as np
from onlinealgo import OnlineAlgo, CannotCoverException
import scipy


class MultiplicativeWeightsAlgo(OnlineAlgo):
    """Online algorithm that grows a solution vector by multiplicative updates.

    Each request raises the entries of the sets connected to the requested
    element until their combined value reaches 1 (or a penalty budget is
    used up). Connections are first multiplied by a ``prediction`` vector,
    so sets with a falsy prediction entry contribute nothing.

    NOTE(review): ``self.solution`` and ``self.sc_input`` are not assigned in
    this class; presumably ``OnlineAlgo.__init__`` sets them -- confirm.
    """

    def __init__(self, sc_input, prediction=None):
        """Set up the algorithm for one problem instance.

        Args:
            sc_input: Problem instance. This class reads its ``num_sets``,
                ``connections`` and ``set_prices`` attributes.
            prediction: Optional array with one entry per set that is
                multiplied into the connections. Defaults to all ``True``.
                The array is copied, so the caller's object is not shared.
        """
        super().__init__(sc_input)
        if prediction is None:
            prediction = np.full(sc_input.num_sets, True)
        self.prediction = prediction.copy()
        self.connections = sc_input.connections * scipy.sparse.csr_array(self.prediction)

    def request_element(self, elem, pen=np.inf, simulate=False):
        """Raise the solution until ``elem`` is covered or ``pen`` is spent.

        Args:
            elem: Row index into ``self.connections`` of the requested element.
            pen: Budget for the update loop. Every iteration adds the minimum
                set price to a running cost, and the loop stops once that
                cost is no longer below ``pen``.
            simulate: If true, ``self.solution`` is left untouched and the
                updated vector is only returned.

        Returns:
            The solution vector after the updates. When no update was needed
            this is the very object stored in ``self.solution``.

        Raises:
            CannotCoverException: If the element's connection row sums to 0,
                i.e. no usable set contains the element.
        """
        solution = self.solution
        cur_conn_row = self.connections[[elem], :].toarray()[0, :]
        num_connected = cur_conn_row.sum()
        if num_connected == 0:
            raise CannotCoverException('no set containing element')

        set_prices = self.sc_input.set_prices
        # NOTE(review): the update assumes strictly positive prices. With a
        # minimum price of 0 both rows below leave the solution unchanged, so
        # the loop cannot make progress when ``pen`` is infinite -- confirm
        # that callers never pass such instances.
        scalar = set_prices.min()

        # The update rows depend only on the requested element and on the
        # prices, so they are built once here instead of on every iteration.
        not_connected = np.logical_not(cur_conn_row)
        # Multiplier: 1 + min_price / price for connected sets, 1 elsewhere.
        mul_row = 1 + scalar / set_prices
        mul_row[not_connected] = 1
        # Increment: min_price / (num_connected * price) for connected sets,
        # 0 elsewhere.
        add_row = scalar / (num_connected * set_prices)
        add_row[not_connected] = 0

        pen_cost = 0
        while cur_conn_row.dot(solution) < 1 and pen_cost < pen:
            # Builds a new array each time; the stored solution is never
            # modified in place, which keeps ``simulate=True`` side-effect free.
            solution = solution * mul_row + add_row
            pen_cost += scalar

        if not simulate:
            self.solution = solution
        return solution
