from gurobipy import Model, GRB, quicksum
import gurobipy as gp
from itertools import combinations
import numpy as np


class Oracle:
    """Select an item subset (assortment) that maximizes a weighted average of coefficients.

    For a subset ``S`` the objective is
    ``sum_{i in S} p_i * v_i / sum_{j in S} v_j``.
    The MIP formulation additionally forces items 0 and 1 into the subset and
    caps the subset size at ``k``.
    """

    def __init__(self, N, k, p, v, method='greedy'):
        """Store the problem data.

        Args:
            N: Number of candidate items (indices ``0 .. N-1``).
            k: Maximum number of items the MIP may select.
            p: Per-item objective coefficients; stored as a numpy array.
            v: Per-item weights used both in the numerator and denominator.
            method: Solver selector for :meth:`assortment`. Only ``'mip'`` is
                currently supported; any other value (including the default
                ``'greedy'``) makes :meth:`assortment` raise ``ValueError``.
        """
        self.N = N
        self.k = k
        self.coef = np.array(p)
        self.v = v
        self.method = method

    def calculate_objective(self, S):
        """Return ``sum_{i in S} coef_i * v_i / sum_{i in S} v_i``.

        Unlike the MIP, this does NOT add items 0 and 1 to ``S``; the caller
        is responsible for passing the exact subset to evaluate.

        Args:
            S: Iterable of item indices (materialized once, so generators work).

        Returns:
            The objective value, or ``0.0`` when ``S`` is empty.
        """
        S = list(S)
        if not S:
            return 0.0
        sum_v = sum(self.v[i] for i in S)
        return sum(self.coef[i] * self.v[i] for i in S) / sum_v

    def find_optimal_subset_mip(self):
        """Solve the assortment problem exactly with Gurobi.

        The fractional objective is linearized with the Charnes-Cooper style
        substitution ``t = 1 / sum(v_i x_i)``, which leaves a bilinear
        (non-convex) model, hence ``NonConvex = 2``.

        Returns:
            List of selected item indices in increasing order.

        Raises:
            ValueError: If ``N < 2`` (items 0 and 1 are always forced in).
            RuntimeError: If Gurobi finds no feasible solution, e.g. when
                ``k < 2`` or the weights cannot satisfy ``sum(v_i x_i) >= 1``.
        """
        n = self.N
        if n < 2:
            raise ValueError("the MIP requires N >= 2 because items 0 and 1 are always included")

        # Plain floats keep numpy scalars out of the Gurobi expression algebra.
        coef = [float(self.coef[i]) for i in range(n)]
        weights = [float(self.v[i]) for i in range(n)]

        # Context managers dispose the model and environment (and release the
        # license) even if building or solving raises.
        with gp.Env(empty=True) as env:
            env.setParam("OutputFlag", 0)
            env.start()
            with Model('assortment_optimization', env=env) as m:
                m.setParam('NonConvex', 2)

                x = m.addVars(n, vtype=GRB.BINARY, name="x")
                t = m.addVar(name="t")

                denom = quicksum(weights[i] * x[i] for i in range(n))
                numer = quicksum(coef[i] * weights[i] * x[i] for i in range(n))

                m.setObjective(t * numer, GRB.MAXIMIZE)
                m.addConstr(t * denom == 1, "transformation")
                # NOTE(review): t * denom == 1 already forces denom > 0; this
                # ">= 1" bound additionally makes the model infeasible when the
                # selected weights sum to less than 1 - confirm v is scaled so
                # that this is intended.
                m.addConstr(denom >= 1, "Avoid_division_by_zero")
                m.addConstr(x.sum() <= self.k, "Cardinality")
                m.addConstr(x[0] == 1, "include_first")
                m.addConstr(x[1] == 1, "include_second")

                m.optimize()

                if m.SolCount == 0:
                    raise RuntimeError(
                        f"Gurobi found no feasible assortment (status {m.Status}); "
                        "check that k >= 2 and that the weights allow sum(v) >= 1"
                    )
                # Read solution values before the model is disposed.
                return [i for i in range(n) if x[i].X > 0.5]

    def assortment(self):
        """Dispatch to the solver selected by ``self.method``.

        Raises:
            ValueError: If ``self.method`` is not ``'mip'``.
        """
        if self.method == 'mip':
            return self.find_optimal_subset_mip()
        # TODO(review): the constructor default 'greedy' is not implemented.
        raise ValueError("Invalid method")

    def set_parameters(self, v, p):
        """Replace the item weights ``v`` and coefficients ``p`` in place."""
        self.v = v
        self.coef = np.array(p)

