import numpy as np
import random
from scipy.optimize import minimize
from scipy.optimize import linprog

import pulp
from itertools import combinations, chain


###################################
### Revenue-Maximizing Solution ###
###################################

def rev_max_sol(w, p, r, K):
    """
    Return the revenue-maximizing solution.

    Solves the LP
        max  sum_j p[j] * sum_S x[j][S] * R_j(S)
        s.t. sum_S x[j][S] == 1 for every user type j,  0 <= x[j][S] <= 1,
    where S ranges over all assortments of 1..K items and
        R_j(S) = sum_{i in S} r[i] * w[i][j] / (1 + sum_{k in S} w[k][j]).

    Parameters:
    - w: weights, indexed as w[i][j] (item i, user type j).
    - p: per-user-type multipliers indexed by j; len(p) is the number of user
      types (presumably user-type probabilities -- confirm with callers).
    - r: per-item revenues; len(r) is the number of items.
    - K: maximum assortment size.

    Returns:
    - (x_solution, optimal_revenue): x_solution maps j -> {S: LP value of
      x[j][S]} for every assortment S; optimal_revenue is the objective value.
      The solver status is not checked here.
    """

    N = len(r)  # number of items
    M = len(p)  # number of user types

    # All assortments of 1..K items, as sorted tuples of item indices
    all_assortments = list(chain.from_iterable(combinations(range(N), k) for k in range(1, K + 1)))

    problem = pulp.LpProblem("RevenueMax", pulp.LpMaximize)

    # x[j][S]: probability of offering assortment S to user type j
    x = {}
    for j in range(M):
        x[j] = pulp.LpVariable.dicts(f"x_{j}", all_assortments, lowBound=0, upBound=1, cat='Continuous')

    # Objective: expected revenue; the coefficient of x[j][S] is p[j] * R_j(S)
    objective = pulp.lpSum([p[j] * pulp.lpSum([x[j][assortment] * sum(r[i] * w[i][j] for i in assortment) / (1 + sum(w[i][j] for i in assortment))
                                               for assortment in all_assortments]) for j in range(M)])
    problem += objective

    # Each x[j] must be a probability distribution over assortments
    for j in range(M):
        problem += pulp.lpSum([x[j][assortment] for assortment in all_assortments]) == 1

    # Solve silently with CBC (no solver log on stdout), matching the other
    # LP solvers in this module.
    solver = pulp.PULP_CBC_CMD(msg=False)
    problem.solve(solver)

    x_solution = {j: {S: pulp.value(x[j][S]) for S in all_assortments} for j in range(M)}

    optimal_revenue = pulp.value(problem.objective)

    return x_solution, optimal_revenue

def generate_all_assortments(num_items, min_size, max_size):
    """
    Enumerate every assortment whose size lies in [min_size, max_size].

    Each assortment is a tuple of item indices from range(num_items) in
    increasing order. Assortments are listed by size (smallest first) and,
    within a size, in itertools.combinations order. An empty list is
    returned when min_size > max_size or no assortment of those sizes exists.
    """
    sizes = range(min_size, max_size + 1)
    return list(chain.from_iterable(combinations(range(num_items), size) for size in sizes))

def random_assortments(w, K):
    """
    Build the uniform-random policy over assortments of 1..K items.

    Every user type (column of ``w``) receives the same distribution, in
    which each assortment of 1 to K items has probability
    1 / (number of such assortments).

    Returns:
    - {j: {assortment: probability}} for every user type j.
    """
    num_items, num_users = w.shape
    candidates = generate_all_assortments(num_items, 1, K)

    uniform = 1 / len(candidates)

    return {user: dict.fromkeys(candidates, uniform) for user in range(num_users)}

def random_sizeK_assortments(w, K, all_user_type=False):
    """
    Build the uniform-random policy over assortments of exactly K items.

    Parameters:
    - w: 2D weight array; its shape gives (number of items, number of user types).
    - K: assortment size.
    - all_user_type: if True, return one (independent) copy of the
      distribution per user type; otherwise return a single distribution.

    Returns:
    - {assortment: probability}, or {j: {assortment: probability}} when
      all_user_type is True.
    """
    num_items, num_users = w.shape
    candidates = generate_all_assortments(num_items, K, K)

    uniform = 1 / len(candidates)
    distribution = dict.fromkeys(candidates, uniform)

    if not all_user_type:
        return distribution
    return {user: dict(distribution) for user in range(num_users)}

def single_assortment(w, A, all_user_type=False):
    """
    Deterministically offer the fixed assortment A.

    A is normalized to a sorted tuple so it matches the keys produced by the
    assortment generators.

    Parameters:
    - w: 2D weight array; only its second dimension (user types) is used.
    - A: iterable of item indices.
    - all_user_type: if True, return a separate distribution for each user type.

    Returns:
    - {A: 1.0}, or {j: {A: 1.0}} for every user type j when all_user_type is True.
    """
    _, num_users = w.shape

    key = tuple(sorted(A))

    if not all_user_type:
        return {key: 1.0}
    return {user: {key: 1.0} for user in range(num_users)}


##########################
### User-Fair Solution ###
##########################

def compute_user_fair(w, K):
    """
    The user-fair solution is the solution that generates the highest utilities;
    i.e. to show the K items with the highest popularity weights.

    For each user type j (column of ``w``), the K items with the largest
    ``w[:, j]`` (ties broken by ``np.argsort`` order) get probability 1.
    Every other assortment of 1..K items gets probability 0.

    Parameters:
    - w: 2D array of weights, w[i, j] for item i and user type j.
    - K: assortment size to show.

    Returns:
    - {j: {assortment: 0 or 1}} over all assortments of 1..K items.
    """

    num_users = w.shape[1]
    assortments = generate_all_assortments(w.shape[0], 1, K)
    optimal_assortments = {j: {assortment: 0 for assortment in assortments} for j in range(num_users)}

    for j in range(num_users):
        # Generated assortments are sorted tuples of ints, so build the top-K
        # key the same way and look it up directly instead of scanning and
        # comparing sets against every assortment.
        top_k = tuple(sorted(int(i) for i in np.argsort(w[:, j])[-K:]))

        # Only mark an assortment that actually exists (e.g. for K <= 0,
        # argsort(...)[-0:] selects every item but no assortments exist).
        if top_k in optimal_assortments[j]:
            optimal_assortments[j][top_k] = 1

    return optimal_assortments

##########################
### Item-Fair Solution ###
##########################

def compute_item_maxmin_fair(w, p, r, K, obj="revenue"):
    """
    Compute the item-fair solution under maxmin fairness.

    Solves an LP that maximizes z, a lower bound on every item's outcome
    (so z is the smallest item outcome at the optimum), over randomized
    assortments x[j][S] in [0, 1] for user type j and assortment S of
    1..K items.

    Parameters:
    - w: weights, indexed as w[i][j] (item i, user type j).
    - p: per-user-type multipliers indexed by j; len(p) is the number of user
      types (presumably user-type probabilities -- confirm with callers).
    - r: per-item revenues; len(r) is the number of items.
    - K: maximum assortment size.
    - obj: item outcome being equalized:
        "revenue":    sum_j p[j] * sum_{S containing i} x[j][S] * r[i] * w[i][j] / (1 + sum_{k in S} w[k][j])
        "visibility": sum_j p[j] * sum_{S containing i} x[j][S]

    Returns:
    - {j: {S: LP value of x[j][S]}} for every assortment S. The solver status
      is not checked here.

    Raises:
    - ValueError if obj is neither "revenue" nor "visibility".
    """
    N = len(r)
    M = len(p)

    # Initialize the optimization problem
    problem = pulp.LpProblem("MaxMin-Item-Fairness", pulp.LpMaximize)

    # All assortments of 1..K items, as sorted tuples of item indices
    all_assortments = [tuple(comb) for k in range(1, K + 1) for comb in combinations(range(N), k)]

    # Decision variables: x[j][S] = probability of offering S to user type j
    x = {}
    for j in range(M):
        x[j] = pulp.LpVariable.dicts("x_{}".format(j), [S for S in all_assortments], lowBound=0, upBound=1, cat='Continuous')

    # z: common lower bound on all item outcomes (the quantity maximized)
    z = pulp.LpVariable("z", lowBound=0)

    # Objective function
    problem += z, "Objective"

    # One constraint per item: outcome_i >= z
    if obj == "revenue":
        for i in range(N):
            problem += pulp.lpSum(p[j] * pulp.lpSum(x[j][S] * r[i] * w[i][j] / (1 + sum(w[k][j] for k in S)) \
                                                    for S in all_assortments if i in S) \
                                  for j in range(M)) >= z, \
                       f"OutcomeConstraint_{i}"
    elif obj == "visibility":
        for i in range(N):
            problem += pulp.lpSum(p[j] * pulp.lpSum(x[j][S] for S in all_assortments if i in S) \
                                  for j in range(M)) >= z, \
                       f"OutcomeConstraint_{i}"
    else:
        raise ValueError("Item outcome not valid.")

    # Probability distribution constraint. This is an inequality, so a
    # returned x[j] may sum to less than 1 (some probability of offering
    # nothing). NOTE(review): confirm `<= 1` rather than `== 1` is intended.
    for j in range(M):
        problem += pulp.lpSum(x[j][S] for S in all_assortments) <= 1, f"ProbDistConstraint_{j}"

    # Solve the problem
    solver = pulp.PULP_CBC_CMD(msg=False)  # CBC solver with no messages
    problem.solve(solver)

    # Extract the solution as {j: {S: value}}, including zero-valued entries
    x_solution = {j: {S: pulp.value(x[j][S]) for S in all_assortments} for j in range(M)}

    return x_solution

def compute_item_KS_fair(w, p, r, K, obj="revenue"):
    """
    Compute the item-fair solution under K-S fairness.

    Presumably K-S refers to Kalai-Smorodinsky fairness (confirm). The LP
    maximizes z over randomized assortments x[j][S] in [0, 1] (user type j,
    assortment S of 1..K items).

    For obj="revenue" each item's outcome is normalized:
        outcome_i >= z * sum_j p[j] * r[i] * w[i][j] / (1 + w[i][j]),
    where the right-hand factor is the revenue item i would earn if it were
    the only item offered to every user type. At the optimum, z is the
    smallest ratio of achieved to that stand-alone revenue.

    For obj="visibility" the constraint is NOT normalized (outcome_i >= z).
    Every item's stand-alone visibility would be sum(p), the same constant
    for all items, so normalizing by it would not change the optimal x.

    Parameters:
    - w: weights, indexed as w[i][j] (item i, user type j).
    - p: per-user-type multipliers indexed by j; len(p) is the number of user types.
    - r: per-item revenues; len(r) is the number of items.
    - K: maximum assortment size.
    - obj: "revenue" or "visibility".

    Returns:
    - {j: {S: value}} containing only assortments whose LP value is strictly
      positive; zero entries are dropped. The solver status is not checked.

    Raises:
    - ValueError if obj is neither "revenue" nor "visibility".
    """
    N = len(r)
    M = len(p)

    # Initialize the optimization problem
    problem = pulp.LpProblem("KS-Item-Fairness", pulp.LpMaximize)

    # All assortments of 1..K items, as sorted tuples of item indices
    all_assortments = [tuple(comb) for k in range(1, K + 1) for comb in combinations(range(N), k)]

    # Decision variables: x[j][S] = probability of offering S to user type j
    x = {}
    for j in range(M):
        x[j] = pulp.LpVariable.dicts("x_{}".format(j), [S for S in all_assortments], lowBound=0, upBound=1, cat='Continuous')

    # z: common lower bound on the (normalized) item outcomes
    z = pulp.LpVariable("z", lowBound=0)

    # Objective function
    problem += z, "Objective"

    # One constraint per item
    if obj == "revenue":
        for i in range(N):
            # Right-hand side: z times item i's revenue when offered alone to
            # every user type (a constant, so the constraint stays linear).
            problem += pulp.lpSum(p[j] * pulp.lpSum(x[j][S] * r[i] * w[i][j] / (1 + sum(w[k][j] for k in S)) \
                                                    for S in all_assortments if i in S) for j in range(M)) \
                    >= z * pulp.lpSum(p[j] * r[i] * w[i][j]/ (1 + w[i][j]) for j in range(M)) , \
                       f"OutcomeConstraint_{i}"
    elif obj == "visibility":
        for i in range(N):
            problem += pulp.lpSum(p[j] * pulp.lpSum(x[j][S] for S in all_assortments if i in S) \
                                  for j in range(M)) >= z, \
                       f"OutcomeConstraint_{i}"
    else:
        raise ValueError("Item outcome not valid.")

    # Probability distribution constraint (inequality: x[j] may sum to < 1)
    for j in range(M):
        problem += pulp.lpSum(x[j][S] for S in all_assortments) <= 1, f"ProbDistConstraint_{j}"

    # Solve the problem
    solver = pulp.PULP_CBC_CMD(msg=False)  # CBC solver with no messages
    problem.solve(solver)

    # Keep only strictly positive entries.
    # NOTE(review): if the solver leaves a variable unset, pulp.value returns
    # None and the `> 0` comparison raises TypeError; if every stand-alone
    # revenue is 0, z is unbounded -- consider checking problem.status.
    x_solution = {j: {S: pulp.value(x[j][S]) for S in all_assortments if pulp.value(x[j][S]) > 0} for j in range(M)}

    return x_solution

##################################
### Get Item and User Outcomes ###
##################################

def compute_item_outcomes(x, w, p, r, obj="revenue"):
    """
    Return the outcome received by each item, as a list.

    For obj="revenue", item i receives
        sum_j sum_{S containing i} x[j][S] * r[i] * w[i][j] / (1 + sum_{k in S} w[k][j]) * p[j];
    for obj="visibility", item i receives
        sum_j sum_{S containing i} x[j][S] * p[j].

    Parameters:
    - x: {j: {S: probability}} for every j in range(len(p)); S is a tuple of
      item indices.
    - w: weights, indexed as w[i][j] (item i, user type j).
    - p: per-user-type multipliers, indexed by j.
    - r: per-item revenues; len(r) is the number of items.
    - obj: "revenue" or "visibility".

    Returns:
    - list of length len(r): the outcome of each item.

    Raises:
    - ValueError if obj is neither "revenue" nor "visibility".
    """
    if obj not in ("revenue", "visibility"):
        raise ValueError("Item outcome not valid.")

    N = len(r)  # Number of items
    M = len(p)  # Number of user types
    outcomes = [0] * N

    # Visit each (user type, assortment) pair once and credit only the items
    # it contains, rather than rescanning every assortment for every item.
    # For a fixed item, contributions are still added in (j, S) order.
    for j in range(M):
        for S, prob_value in x[j].items():
            # Distinct indices of S that have an outcome slot; other indices
            # never matched the per-item membership test.
            members = {i for i in S if 0 <= i < N}
            if not members:
                continue

            if obj == "revenue":
                # Denominator of the purchase probability: 1 (no purchase)
                # plus the total weight of S for user type j.
                denominator = 1 + sum(w[k][j] for k in S)
                for i in members:
                    outcomes[i] += prob_value * r[i] * w[i][j] / denominator * p[j]
            else:  # "visibility": probability-weighted exposure
                for i in members:
                    outcomes[i] += prob_value * p[j]

    return outcomes

def compute_user_outcomes(x, w):
    """
    Return the outcome received by each user, as a list.

    The utility of offering assortment S to user type j is
    log(1 + sum_{i in S} w[i][j]). User type j's outcome is the
    probability-weighted sum of that utility over its distribution x[j].

    Parameters:
    - x: {j: {S: probability}}; keys j are assumed to be 0..len(x)-1.
    - w: weights, indexed as w[i][j] (item i, user type j).

    Returns:
    - list of length len(x); entry j is user type j's expected utility.
    """
    M = len(x)  # Number of user types
    user_outcomes = [0] * M

    for j in range(M):
        for S, prob_value in x[j].items():
            # Only assortments with positive probability contribute
            if prob_value > 0:
                user_outcomes[j] += prob_value * np.log(1 + sum(w[i][j] for i in S))

    return user_outcomes

def calculate_expected_revenue(x, w, p, r):
    """
    Compute the expected revenue received by the platform.

    Expected revenue is
        sum_j p[j] * sum_S x[j][S] * sum_{i in S} r[i] * w[i][j] / (1 + sum_{k in S} w[k][j]).

    Parameters:
    - x: {j: {S: probability}}, where S is a tuple of item indices.
    - w: weights, indexed as w[i][j]; a NumPy array or a nested list both work.
    - p: per-user-type multipliers, indexed by j.
    - r: per-item revenues, indexed by i.

    Returns:
    - the expected revenue (0 for an empty x).
    """
    expected_revenue = 0

    for j, assortments in x.items():
        for S, prob_S_j in assortments.items():
            # Shared denominator for every item in S: 1 (no purchase) + total weight of S
            denominator = 1 + sum(w[k][j] for k in S)
            assortment_revenue = sum(r[i] * w[i][j] / denominator for i in S)
            expected_revenue += p[j] * prob_S_j * assortment_revenue

    return expected_revenue

##################################
###### Solve Problem FAIR ########
##################################

def solve_fair_recommendation_problem(w, p, r, item_fair_outcomes, user_fair_outcomes, K, delta_item=0, delta_user=0, eta=0, item_obj="revenue"):
    """
    Solve the fair recommendation problem, given fair outcomes for items and users.

    Maximizes expected revenue over probability distributions x[j] (one per
    user type j, summing to exactly 1) over assortments of 1..K items,
    subject to:
      - item constraints:  outcome_i >= delta_item * item_fair_outcomes[i] - eta
        (outcome is expected revenue or visibility, per item_obj);
      - user constraints:  sum_S x[j][S] * log(1 + sum_{i in S} w[i][j])
                           >= delta_user * user_fair_outcomes[j] - eta.

    Parameters:
    - w: weights, indexed as w[i][j] (item i, user type j).
    - p: per-user-type multipliers indexed by j; len(p) is the number of user types.
    - r: per-item revenues; len(r) is the number of items.
    - item_fair_outcomes: reference outcome per item, indexed by i.
    - user_fair_outcomes: reference utility per user type, indexed by j.
    - K: maximum assortment size.
    - delta_item, delta_user: fractions of the reference outcomes to guarantee.
    - eta: slack subtracted from every fairness lower bound.
    - item_obj: "revenue" or "visibility".

    Returns:
    - (status, x_solution, optimal_revenue): status is the PuLP status string
      (e.g. "Optimal"); x_solution is {j: {S: LP value}} for every S;
      optimal_revenue is the objective value. Check status before trusting the
      other two -- they are returned even if the LP is infeasible.

    Raises:
    - ValueError if item_obj is neither "revenue" nor "visibility".
    """
    N = len(r)  # Number of items
    M = len(p)  # Number of user types

    # All assortments of 1..K items, as sorted tuples of item indices
    all_assortments = list(chain.from_iterable(combinations(range(N), k) for k in range(1, K + 1)))

    # Initialize the optimization problem
    problem = pulp.LpProblem("FairRecommendation", pulp.LpMaximize)

    # Decision variables: x[j][S] = probability of offering S to user type j
    x = {}
    for j in range(M):
        x[j] = pulp.LpVariable.dicts(f"x_{j}", all_assortments, lowBound=0, upBound=1, cat='Continuous')

    # Objective: maximize expected revenue. The innermost `for i in S` runs in
    # its own generator scope, so it does not clash with the outer `i`.
    revenue = pulp.lpSum(p[j] * x[j][S] * sum(r[i] * w[i][j] / (1 + sum(w[i][j] for i in S)) for i in S)
                         for j in range(M) for S in all_assortments)
    problem += revenue

    # Item-Fair Constraints
    # Note: here we care about the item outcomes
    if item_obj == "revenue":
        for i in range(N):
            # `i` below is this loop's item index; only the innermost
            # `sum(w[i][j] for i in S)` rebinds i, within its own scope.
            problem += pulp.lpSum(p[j] * x[j][S] * r[i] * w[i][j] / (1 + sum(w[i][j] for i in S))
                                  for j in range(M) for S in all_assortments if i in S) >= delta_item * item_fair_outcomes[i] - eta, \
                                  f"ItemFairConstraint_{i}"
    elif item_obj == "visibility":
        for i in range(N):
            problem += pulp.lpSum(p[j] * x[j][S] for j in range(M) for S in all_assortments if i in S) >= delta_item * item_fair_outcomes[i] - eta, \
                                  f"ItemFairConstraint_{i}"
    else:
        raise ValueError("Item outcome not valid.")

    # User-Fair Constraints: expected log-utility of each user type
    for j in range(M):
        problem += pulp.lpSum(x[j][S] * np.log(1 + sum(w[i][j] for i in S)) for S in all_assortments) >= delta_user * user_fair_outcomes[j] - eta, \
                              f"UserFairConstraint_{j}"

    # Prob. Dist. Constraints: Ensure x[j][S] forms a probability distribution for each user type
    for j in range(M):
        problem += pulp.lpSum(x[j][S] for S in all_assortments) == 1, f"ProbDist_{j}"

    # Solve the problem
    solver = pulp.PULP_CBC_CMD(msg=False)  # CBC solver with no messages
    problem.solve(solver)

    # Extract the solution as {j: {S: value}}, including zero-valued entries
    x_solution = {j: {S: pulp.value(x[j][S]) for S in all_assortments} for j in range(M)}

    status = pulp.LpStatus[problem.status]

    optimal_revenue = pulp.value(problem.objective)

    return status, x_solution, optimal_revenue

def draw_purchase_decision(S, w, j, *, rng=None):
    """
    Draw a purchase decision for a given assortment S and user type j based on the weights w.

    Item i in S is bought with probability w[i][j] / (1 + sum_{k in S} w[k][j]).
    No purchase happens with probability 1 / (1 + sum_{k in S} w[k][j]).

    Parameters:
    - S: The offered assortment (any iterable of item indices, e.g. a tuple or list).
    - w: The weights for each item for each user type (a 2D list or array where w[i][j] is the weight of item i for user type j).
    - j: The user type.
    - rng: Optional ``random.Random`` instance for reproducible draws; by
      default the global ``random`` module is used.

    Returns:
    - The index of the purchased item, or -1 if no purchase is made.
    """
    # Materialize S once so lists and one-shot iterables work like tuples
    items = tuple(S)

    denominator = 1 + sum(w[i][j] for i in items)

    # Item purchase probabilities followed by the no-purchase probability
    probabilities = [w[i][j] / denominator for i in items] + [1 / denominator]

    sampler = random if rng is None else rng

    # -1 is the sentinel outcome for "no purchase"
    return sampler.choices(items + (-1,), weights=probabilities, k=1)[0]
