from math import log, exp
import numpy as np
import matplotlib.pyplot as plt
import warnings
import seaborn as sns

class Profiling:
    '''
    Build a piecewise profile `phi` over utilization w (feasible when it ends at w <= 1)
    for a sequence of prices `ps`, using one alpha per price.

    For each price p (> 1) an "exp" segment is added on which the rate grows
    from the reservation price r to p as (r - 1) * exp(alpha * (w - w0)) + 1.
    When the reservation price would fall below the previous price, the exp
    segment is preceded by a flat "line" segment at the previous price.

    Attributes:
        ps: prices, processed in the given order; values <= 1 are skipped.
        alphas: one alpha per entry of ps (indexed in parallel with ps).
        rs: reservation price at the start of each exp segment.
        ws: utilization breakpoints, starting at 0.
        ss: integral of the profile up to each breakpoint, starting at 0.
        phi: list of segments with layouts
             ["line", start_w, end_w, price, price, alpha]
             ["exp",  start_w, end_w, start_price, end_price, res_price, alpha]
    '''

    def __init__(self, ps, alphas):
        self.ps = ps
        self.alphas = alphas
        self.reset()

    def reset(self):
        '''Clear all computed state (rs, ws, ss, phi); ps and alphas are kept.'''
        self.rs = []
        self.ws = [0]
        self.ss = [0]
        self.phi = []

    def get_reservation_price(self, alpha, s, pre_w):
        '''Reservation price alpha * (s + 1 - pre_w) at breakpoint pre_w with integral s.'''
        return alpha*(s+1-pre_w)

    def get_w(self, alpha, p, r, pre_w):
        '''
        Utilization at which an exp segment that starts at pre_w with
        reservation price r reaches price p, i.e. the w solving
        (r - 1) * exp(alpha * (w - pre_w)) + 1 = p.

        Raises:
            ValueError / ZeroDivisionError when (p-1)/(r-1) is not a positive
            finite number (e.g. r == 1, or r < 1 < p). The offending inputs
            are printed before the exception is re-raised.
        '''
        try:
            return (1/alpha)*log((p-1)/(r-1))+pre_w
        except (ArithmeticError, ValueError):
            print("p", p)
            print("r", r)
            print("pre_w", pre_w)
            print("alpha", alpha)
            raise

    def get_wprime(self, alpha, p, s, prew):
        '''
        End of a flat segment at price p that starts at prew with integral s,
        chosen so that the reservation price at its end equals p:
        alpha * (s + p * (w' - prew) + 1 - w') = p.
        '''
        return (p-alpha*(s-prew*p+1))/(alpha*(p-1))

    def get_value_of_integral(self, alpha, r, w, pre_w):
        '''Integral of (r - 1) * exp(alpha * (x - pre_w)) + 1 over x in [pre_w, w].'''
        return (1/alpha)*(r-1)*(exp(alpha*(w-pre_w)) - 1) + w - pre_w

    def get_profile(self):
        '''
        Compute phi and the breakpoints (ws, rs, ss) that follow the profile
        for the current ps and alphas.

        Returns:
            True if every breakpoint stays within utilization 1 (feasible).
            False as soon as a breakpoint exceeds 1; the state is then left
            partially filled.
        '''
        self.reset()
        for i, p in enumerate(self.ps):
            p = max(p, 1)
            if p == 1:
                # Prices at or below 1 contribute no segment.
                continue
            alpha = self.alphas[i]
            r_i = self.get_reservation_price(alpha, self.ss[-1], self.ws[-1])
            # Clamp the previous price like p, so a skipped (<= 1) price acts as 1.
            pre_p = 1 if i == 0 else max(self.ps[i-1], 1)
            if r_i < pre_p:
                # The reservation price would drop below the previous price:
                # hold the rate flat at pre_p until the reservation price reaches it.
                pre_w = self.ws[-1]
                w_prime = self.get_wprime(alpha, pre_p, self.ss[-1], pre_w)
                s_prime = self.ss[-1] + pre_p*(w_prime-pre_w)
                self.ws.append(w_prime)
                self.ss.append(s_prime)
                self.phi.append(["line", pre_w, w_prime, pre_p, pre_p, alpha])
                self.rs.append(self.get_reservation_price(alpha, s_prime, w_prime))
            else:
                self.rs.append(r_i)
            self.ws.append(self.get_w(alpha, p, self.rs[-1], self.ws[-1]))
            self.ss.append(self.ss[-1] + self.get_value_of_integral(alpha, self.rs[-1], self.ws[-1], self.ws[-2]))
            self.phi.append(["exp", self.ws[-2], self.ws[-1], pre_p, p, self.rs[-1], alpha])
            if self.ws[-1] > 1:
                return False
        return True

    def plot(self, fig_name):
        '''
        Plot the segments in self.phi and save the figure as a PDF to fig_name.

        Call get_profile() (or one of the binary searches) first so that
        self.phi is populated.
        '''
        def f(x, phi):
            # Evaluate one segment on x; NaN outside [start_w, end_w].
            inside = (phi[1] <= x) & (x <= phi[2])
            if phi[0] == "exp":
                # Starts at the reservation price phi[5] and ends at end_price phi[4].
                return np.where(inside, (phi[5] - 1) * np.exp(phi[6] * (x - phi[1])) + 1, np.nan)
            return np.where(inside, phi[3], np.nan)

        sns.set_theme()
        sns.set_context("paper")
        sns.set_style("white")
        plt.figure(figsize=(8, 6))
        ws_names = ["w1"]
        ws_ticks = [0]
        ps_names = ["q1"] + [f"q{i}" for i in range(2, len(self.ps)+2)]
        ps_ticks = [1] + list(self.ps)
        i = 2
        for phi in self.phi:
            if phi[0] == "exp":
                ws_names.append(f"w{i}")
                i += 1
                ws_ticks.append(phi[2])  # end_w of the exp segment
            if phi[0] == "line":
                ws_ticks[-1] = phi[2]  # move the last tick to the end of the flat part
        for i, phi in enumerate(self.phi):
            x_values = np.linspace(phi[1], phi[2], 1000)
            y_values = f(x_values, phi)
            color = "darkblue" if phi[0] == "line" else "darkred"
            plt.plot(x_values, y_values, label=f'f{i+1}(x)', linewidth=2, color=color)

        plt.xticks(ws_ticks, ws_names)
        plt.yticks(ps_ticks, ps_names)
        plt.rcParams.update({'font.size': 22})
        plt.xlabel('utilization')
        plt.ylabel('reservation rate')
        plt.grid(True)
        plt.savefig(fname=fig_name, format="pdf", dpi=200)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()

    def _bisect_middle_alpha(self, first, last, error):
        '''
        Bisect c on [1, last] for the smallest c (up to `error`) such that the
        profile with alphas [first, c, last] is followable.

        This assumes followability is monotone in c; the bisection relies on it.
        Returns c rounded to 3 decimals, or None (with a warning) if even
        c = last is not followable. On success, the instance state describes
        the profile for the unrounded c.
        '''
        h = last
        l = 1
        self.alphas = [first, h, last]
        if not self.get_profile():
            warnings.warn("No c can be found")
            return None
        while abs(h - l) > error:
            c = (h + l) / 2
            self.alphas = [first, c, last]
            if self.get_profile():
                h = c
            else:
                l = c
        # h is the smallest value verified feasible; the last midpoint may not be.
        c = h
        self.alphas = [first, c, last]
        self.get_profile()
        return round(c, 3)

    def do_binary_search_3_intervals_symmetric(self, r, error = 0.0001):
        '''
        For three price intervals [1,p1][p1,p2][p2,M], find the smallest middle
        alpha c such that the profile with alphas [r, c, r] is followable.

        Returns c rounded to 3 decimals, or None (with a warning) if [r, r, r]
        itself is not followable.
        '''
        return self._bisect_middle_alpha(r, r, error)

    def do_binary_search_3_intervals_asymmetric(self, r1, r2, error = 0.0001):
        '''
        For three price intervals [1,p1][p1,p2][p2,M], find the smallest middle
        alpha c such that the profile with alphas [r1, c, r2] is followable.

        Returns c rounded to 3 decimals, or None (with a warning) if
        [r1, r2, r2] itself is not followable.
        '''
        return self._bisect_middle_alpha(r1, r2, error)


if __name__ == "__main__":
    # Demo: three price levels. Find the smallest middle alpha c that keeps
    # the symmetric profile [r, c, r] followable, then dump its segments.
    # Alternative five-level instance:
    #   prices = [20, 35, 50, 70, 100]
    #   alphas = [7, 5, 3, 3.5, 4]
    r = 4
    r2 = 3  # only used by the asymmetric search shown below
    prices = [50, 70, 100]
    profile = Profiling(prices, None)  # alphas are set by the search itself
    best_c = profile.do_binary_search_3_intervals_symmetric(r)
    print(best_c, '\n')
    for segment in profile.phi:
        print(segment)
    # Other entry points worth trying:
    #   print(profile.do_binary_search_3_intervals_asymmetric(r, r2))
    #   profile.get_profile()
    #   print(profile.phi, profile.ws)
    #   profile.plot("profile1.pdf")