from oracle import Oracle
from environment import Environment
import numpy as np


class Algorithms:
    """Assortment-selection policies simulated against an ``Environment``.

    Every policy runs for ``T`` steps.  At each step it offers an assortment,
    adds that assortment's revenue (as computed by the oracle built on the
    utilities ``v_cv``) to a running total, and feeds the assortment to a
    fresh ``Environment`` to obtain the observed decisions.

    Item ``0`` is treated specially by all policies: its utility is fixed to
    ``1`` and it is never added to the per-item epoch lists.
    """

    def __init__(self, T, N, k, p, delay_threshold, mean_return_delay, upper_bound_delay, skew, v_cv,
                 method=None):
        """Store the experiment parameters and build the reference oracle.

        Args:
            T: Number of steps each policy is run for.
            N: Total number of items.
            k: Maximum number of items in the subset.
            p: Revenue of each item.
            delay_threshold: Forwarded to ``Environment``; also scales the last
                bonus term in :meth:`demba`.
            mean_return_delay: Forwarded to ``Environment``; also scales the
                last bonus term in :meth:`demba_nothreshold`.
            upper_bound_delay: Forwarded to ``Environment``.
            skew: Forwarded to ``Environment``.
            v_cv: Item utilities given to the reference oracle and to the
                environment.
            method: Optional label stored on the instance.  No method of this
                class reads it.  It defaults to ``None``; the original code
                assigned an undefined name here and raised ``NameError``.
        """
        self.T = T
        self.N = N  # Total number of items
        self.k = k  # Maximum number of items in the subset
        self.p = p  # Revenue of each item

        self.delay_threshold = delay_threshold
        self.mean_return_delay = mean_return_delay
        self.upper_bound_delay = upper_bound_delay
        self.skew = skew

        self.method = method
        self.v_cv = v_cv
        self.cv_oracle = Oracle(N=self.N, k=self.k, p=self.p, v=self.v_cv)

    def revenue(self, S):
        """Return the objective of assortment ``S`` under the utilities ``v_cv``."""
        return self.cv_oracle.calculate_objective(S)

    def _new_environment(self):
        """Build a fresh environment from the stored utility and delay parameters."""
        return Environment(item_utilities=self.v_cv, mean_delay=self.mean_return_delay,
                           delay_threshold=self.delay_threshold, upper_bound_delay=self.upper_bound_delay,
                           skew=self.skew)

    def _run_optimistic(self, bonus):
        """Run the optimistic policy loop shared by the two ``demba`` variants.

        Args:
            bonus: Callable ``bonus(sales, n_epochs, epoch)`` returning the
                value added to ``sales / n_epochs`` to form an item's
                optimistic utility (the sum is capped at 1).

        Returns:
            ``(total_revenue, total_sales, E_i_e, S_e_bin)``: accumulated
            revenue, per-item sale counts, per-item list of epoch indices in
            which the item was offered, and the list of offered assortments.
        """
        t = 0
        e = 0  # epoch counter, advanced whenever env.step reports immediate_return

        hat_v_i_e = dict.fromkeys(range(self.N), 0)  # cumulative observed sales per item
        total_sales = dict.fromkeys(range(self.N), 0)
        total_revenue = 0
        E_i_e = {i: [] for i in range(self.N)}
        bar_v_i_e = dict.fromkeys(range(self.N), 0.5)  # optimistic utilities fed to the oracle
        bar_v_i_e[0] = 1
        S_e_bin = []

        v = [bar_v_i_e[i] for i in range(self.N)]

        env = self._new_environment()
        oracle = Oracle(N=self.N, k=self.k, p=self.p, v=v)

        while t < self.T:
            oracle.set_parameters(v, self.p)
            S_e = oracle.assortment()
            S_e_bin.append(S_e)

            total_revenue += self.revenue(S_e)

            immediate_return, decisions = env.step(S_e)

            for i in range(self.N):
                if i in decisions:
                    total_sales[i] += 1
                    hat_v_i_e[i] += 1

            if immediate_return:
                for i in S_e:
                    if i != 0:
                        E_i_e[i].append(e)
                        n_epochs = len(E_i_e[i])  # >= 1 because of the append above
                        bar_v_i_e[i] = np.minimum(
                            hat_v_i_e[i] / n_epochs + bonus(hat_v_i_e[i], n_epochs, e), 1)

                e += 1
                # The oracle only sees refreshed utilities once an epoch has closed.
                v = [bar_v_i_e[i] for i in range(self.N)]

            t += 1

        return total_revenue, total_sales, E_i_e, S_e_bin

    def demba(self, c_1, c_2, c_3, c_4):
        """Run the optimistic policy whose bonus includes a ``delay_threshold`` term.

        Args:
            c_1, c_2, c_3, c_4: Multipliers of the four bonus terms.

        Returns:
            ``(total_revenue, total_sales, E_i_e, S_e_bin)``; see
            :meth:`_run_optimistic`.
        """
        def bonus(sales, n_epochs, epoch):
            log_sqrt = np.log(np.sqrt(self.N) * (epoch + 1))
            return c_1 * np.sqrt((sales * log_sqrt) / n_epochs) \
                + c_2 * log_sqrt / n_epochs \
                + c_3 * np.sqrt((sales * np.log(self.N * (epoch + 1)) / n_epochs)) \
                + c_4 * self.delay_threshold / n_epochs

        return self._run_optimistic(bonus)

    def exp(self, exp_duration):
        """Explore with random assortments, then commit to one assortment.

        For the first ``exp_duration`` steps a random assortment is offered
        (items 0 and 1 plus ``k - 1`` items drawn without replacement from
        ``2..N-1``).  Utilities are then estimated once, the oracle picks an
        assortment, and that assortment is offered until step ``T``.

        Args:
            exp_duration: Number of exploration steps.

        Returns:
            ``(total_revenue, total_sales, E_i_e, S_e_bin)`` with the same
            meaning as in :meth:`demba`.
        """
        t = 0
        e = 0

        hat_v_i_e = dict.fromkeys(range(self.N), 0)
        total_sales = dict.fromkeys(range(self.N), 0)
        total_revenue = 0
        E_i_e = {i: [] for i in range(self.N)}
        bar_v_i_e = dict.fromkeys(range(self.N), 0.5)
        bar_v_i_e[0] = 1
        S_e_bin = []

        v = [bar_v_i_e[i] for i in range(self.N)]

        env = self._new_environment()
        oracle = Oracle(N=self.N, k=self.k, p=self.p, v=v)

        while t < exp_duration:
            S_e = np.random.choice(range(2, self.N), self.k - 1, replace=False)
            S_e = np.insert(S_e, 0, 0)
            S_e = np.insert(S_e, 1, 1)
            S_e_bin.append(S_e)

            total_revenue += self.revenue(S_e)

            immediate_return, decisions = env.step(S_e)

            # NOTE(review): this counts every entry of `decisions`, whereas the
            # other loops in this class count each item at most once per step.
            # Kept as in the original; confirm which is intended.
            for i in decisions:
                total_sales[i] += 1
                hat_v_i_e[i] += 1

            if immediate_return:
                for i in S_e[1:]:
                    E_i_e[i].append(e)

                e += 1
            t += 1

        # Item 0 never enters E_i_e, so it is skipped here and keeps the
        # utility of 1 assigned above (the original loop reset it to 0,
        # unlike the demba variants).
        for i in range(1, self.N):
            if E_i_e[i]:
                bar_v_i_e[i] = np.minimum(hat_v_i_e[i] / len(E_i_e[i]), 1)
            else:
                bar_v_i_e[i] = 0
        v = [bar_v_i_e[i] for i in range(self.N)]
        oracle.set_parameters(v, self.p)
        S_e = oracle.assortment()

        while t < self.T:
            S_e_bin.append(S_e)

            total_revenue += self.revenue(S_e)

            immediate_return, decisions = env.step(S_e)

            for i in range(self.N):
                if i in decisions:
                    total_sales[i] += 1
                    hat_v_i_e[i] += 1

            if immediate_return:
                for i in S_e:
                    if i != 0:
                        E_i_e[i].append(e)

                e += 1
            t += 1

        return total_revenue, total_sales, E_i_e, S_e_bin

    def clairvoyant(self):
        """Return ``(T * revenue(S_star), S_star)`` for the reference oracle's assortment."""
        S_star = self.cv_oracle.assortment()
        optimal_revenue = self.T * self.revenue(S_star)
        return optimal_revenue, S_star

    def demba_nothreshold(self, c_1, c_2, c_3):
        """Run the optimistic policy whose bonus uses ``mean_return_delay``.

        Args:
            c_1, c_2, c_3: Multipliers of the three bonus terms.

        Returns:
            ``(total_revenue, total_sales, E_i_e, S_e_bin)``; see
            :meth:`_run_optimistic`.
        """
        def bonus(sales, n_epochs, epoch):
            log_sqrt = np.log(np.sqrt(self.N) * (epoch + 1))
            return c_1 * np.sqrt((sales * log_sqrt) / n_epochs) \
                + c_2 * log_sqrt / n_epochs \
                + c_3 * (self.mean_return_delay / n_epochs)

        return self._run_optimistic(bonus)
