from math import log, exp
import numpy as np
import matplotlib.pyplot as plt
import warnings

def get_reservation_price(alpha, s, pre_w):
    """Return alpha * (s + 1 - pre_w).

    get_phi_values uses this as r_i, the value of phi at the start of an
    interval, where s is the accumulated integral and pre_w the current
    breakpoint.
    """
    headroom = s + 1 - pre_w
    return alpha * headroom

def get_value_of_integral(alpha, r, w, pre_w):
    """Integrate phi(x) = (r - 1) * exp(alpha * (x - pre_w)) + 1 over [pre_w, w].

    Closed form: (r - 1) / alpha * (exp(alpha * (w - pre_w)) - 1) + (w - pre_w).
    The value is 0 when w == pre_w.
    """
    inv_alpha = 1 / alpha
    growth = exp(alpha * (w - pre_w)) - 1
    return inv_alpha * (r - 1) * growth + w - pre_w

def get_w(alpha, p, r, pre_w):
    """Return the w where phi(x) = (r - 1) * exp(alpha * (x - pre_w)) + 1 equals p.

    Solves (r - 1) * exp(alpha * (w - pre_w)) + 1 == p for w.

    Raises:
        ValueError: from math.log, unless (p - 1) / (r - 1) > 0.
        ZeroDivisionError: when r == 1 or alpha == 0.
    """
    ratio = (p - 1) / (r - 1)
    return (1 / alpha) * log(ratio) + pre_w

def invert_phi(alpha, p, r, pre_w):
    """Return the w at which phi(w) = (r - 1) * exp(alpha * (w - pre_w)) + 1 equals p.

    This is exactly the closed form implemented by get_w. Delegating keeps the
    two names from drifting apart.

    Raises:
        ValueError: unless (p - 1) / (r - 1) > 0 (domain of math.log).
        ZeroDivisionError: when r == 1 or alpha == 0.
    """
    return get_w(alpha, p, r, pre_w)

def get_phi_values(ps, alphas):
    """Build the piecewise-exponential profile phi over consecutive price intervals.

    On interval i the curve is phi(x) = (rs[i] - 1) * exp(alphas[i] * (x - ws[i])) + 1.
    This is the curve integrated by get_value_of_integral and inverted by get_w.

    Args:
        ps: right endpoints of the price intervals. Interval 0 is [1, ps[0]];
            per the note below, ps[0] > 1 and ps[-1] == M.
        alphas: growth rate per interval. alphas[i] is used for interval i, so
            it needs at least len(ps) entries.

    Returns:
        (rs, ws, ss):
          rs -- len(ps) values; rs[i] = alphas[i] * (ss[i] + 1 - ws[i])
                (get_reservation_price), computed from ss[i] before any
                rewrite of ss[i] in the same iteration.
          ws -- len(ps) + 1 breakpoints starting at 0. ws[i + 1] is where
                phi on interval i reaches ps[i], or ws[i] when rs[i] > ps[i].
          ss -- len(ps) + 1 running integrals of phi, starting at 0.

    Emits warnings.warn("The profile could not be followed") when ws[-1] > 1.
    """
    # ps start with ps[0] > 1 and finish with M
    # for interval [p_{i-1}, p_i] we calculate r_i, w_i, and s_i := sum of integral until p_{i-1}
    rs = []
    ws = [0]
    ss = [0]
    for i, p in enumerate(ps):
        a = alphas[i]
        # r_i is computed from the current ss[-1]; if ss[-1] is rewritten
        # below, r_i is NOT recomputed from the new value.
        r_i = get_reservation_price(a, ss[-1], ws[-1])
        rs.append(r_i)
        # left endpoint of interval i (the first interval starts at price 1)
        pre_p = 1 if i == 0 else ps[i-1]
        if r_i < pre_p: # r_i < p_{i-1}
            if i > 0:
                # we need to modify the last value of ss
                # NOTE(review): invert_phi is evaluated at ps[i-1] with the
                # previous interval's parameters. These are the same arguments
                # get_w received on the previous iteration, and the formula is
                # identical. So when rs[-2] <= ps[i-1], w_prime equals ws[-1]
                # (up to the max(0, .) clamp), the r_i * (ws[-1] - w_prime)
                # term is 0, and ss[-1] is recomputed to the value it already
                # had. If the intent was the point where the previous curve
                # reaches r_i (plot() caps segment i at r[i+1]), the second
                # argument should presumably be r_i -- TODO confirm.
                w_prime = max(0, invert_phi(alphas[i-1], ps[i-1], rs[-2], ws[-2]))
                # sanity check only: `assert` is stripped under `python -O`
                assert(get_value_of_integral(alphas[i-1], rs[-2], w_prime, ws[-2]) >= 0)
                # integral of the previous curve up to w_prime, plus a constant
                # stretch at height r_i from w_prime to ws[-1]
                ss[-1] = ss[-2] + get_value_of_integral(alphas[i-1], rs[-2], w_prime, ws[-2]) + r_i * (ws[-1] - w_prime)
        if r_i <= p:
            # phi on this interval reaches p at the returned breakpoint
            ws.append(get_w(a, p, r_i, ws[-1]))
        else:
            # phi already starts above p: zero-width interval, w unchanged
            ws.append(ws[-1])
        # add the integral of phi over [ws[-2], ws[-1]] for this interval
        ss.append(ss[-1] + get_value_of_integral(a, r_i, ws[-1], ws[-2]))
    # a final breakpoint beyond 1 means the profile could not be completed
    if ws[-1] > 1:
        warnings.warn("The profile could not be followed")
    return rs, ws, ss

def plot(rs, ws, as_):
    """Plot the piecewise profile phi(w), e.g. as returned by get_phi_values.

    Args:
        rs: value of phi at the start of each segment (rs[i] at ws[i]).
        ws: breakpoints. Segment i spans [ws[i], ws[i + 1]], so len(ws) must be
            at least len(rs) + 1.
        as_: growth rate of each segment (at least len(rs) entries).

    Segment i is drawn as (rs[i] - 1) * exp(as_[i] * (x - ws[i])) + 1. Every
    segment but the last is capped at rs[i + 1]. The figure is shown with
    plt.show().
    """

    def f(x, r, w, a, i):
        # Segment i evaluated on the grid x. Points outside [w[i], w[i+1]] are
        # NaN, so the plotted line ends at the segment boundary.
        inside = (w[i] <= x) & (x <= w[i + 1])
        result = np.where(inside, (r[i] - 1) * np.exp(a[i] * (x - w[i])) + 1, np.nan)
        if i < len(r) - 1:
            # cap the curve at the next segment's starting value
            result = np.where(result > r[i + 1], r[i + 1], result)
        return result

    plt.figure(figsize=(8, 6))
    for i in range(len(rs)):
        # ws[i + 1] exists for every segment (f reads it). Sample each segment
        # over its own span so none is truncated or coarsely sampled.
        x_values = np.linspace(ws[i], ws[i + 1], 1000)
        y_values = f(x_values, rs, ws, as_, i)
        plt.plot(x_values, y_values, label=f'f{i+1}(x)', linewidth=2)

    plt.xlabel('w')
    plt.ylabel('phi(w)')
    plt.title('Plot of phi(w)')
    plt.legend()
    plt.grid(True)
    plt.show()

def do_binary_search(ps, alphas, alpha, error=0.001, one_sided=False):
    '''
    Bisect for the smallest scale factor d in [0, 2] whose profile is followable.

    A profile is followable when the breakpoints returned by get_phi_values end
    with ws[-1] <= 1. The growth rates tried for a given d are:
      - one_sided=False: alphas * d (every interval's rate scaled by d);
      - one_sided=True: [d * alpha, alpha] for the 2 price intervals
        [1,p1][p1,p2]. In this case `alphas` is not used.

    The search assumes followability is monotone in d (not followable below
    some threshold, followable above it). It stops once the bracket is no
    wider than `error`.

    Returns:
        (rs, ws, ss, d): d is the upper end of the final bracket, and rs, ws, ss
        are get_phi_values' output for that same d. If no midpoint was
        followable, d is 2 and the profile returned is the one at d = 2, which
        may not be followable either.
    '''
    alphas = np.array(alphas)

    def evaluate(d):
        # get_phi_values output for scale factor d
        if one_sided:
            return get_phi_values(ps, [d*alpha, alpha])  # TODO: FIXME
        return get_phi_values(ps, alphas * d)

    # Bracket [l, h] for d. Once updated, h is always followable and l never is.
    h = 2
    l = 0
    best = None  # get_phi_values output for the current h, once h has been tested
    while abs(h - l) > error:
        d = (h + l) / 2
        result = evaluate(d)
        if result[1][-1] <= 1:
            h = d
            best = result
        else:
            l = d
    if best is None:
        # h was never tested (no followable midpoint, or the loop never ran)
        best = evaluate(h)
    rs, ws, ss = best
    return rs, ws, ss, h

def get_delta_2(ps, alphas, ind_cons, error=10**(-18)):
    '''
    Bisect for the slope d2 of the growth rates right of the consistency interval.

    ex:
    ps = [1,2,3,4,5,6] and ind_cons = 4 --> consistency is for [4,5]
    alphas go from [a1:[1,2],a2:[2,3],a3:[3,4],a4:[4,5]] --> we need to determine for [5,6]

    For every interval [ps[ind], ps[ind+1]] with ind >= ind_cons, the rate is
        d2 * (its midpoint - midpoint of [ps[ind_cons-1], ps[ind_cons]]) + alphas[-1].
    d2 is bisected over [0, ps[-1]]. The bracket moves down when
    get_phi_values(ps[1:], rates) ends with ws[-1] <= 1, and up otherwise.

    `alphas` is not modified: the extended rates are built on a copy, so
    repeated calls with the same list see the same alphas[-1].

    The loop also stops when an iteration leaves [l, h] unchanged. That happens
    when l and h are adjacent floats, which the tiny default `error` can reach.

    Returns:
        (d2, num_its, l, h): the last evaluated d2, the number of evaluations,
        and the final bracket.
    '''
    alphas = list(alphas)  # copy: never extend the caller's list
    cons = alphas[-1]
    x_cons = (ps[ind_cons-1] + ps[ind_cons]) / 2
    # signed distance of each later interval's midpoint from x_cons
    tail = range(ind_cons, len(ps) - 1)
    offsets = [(ps[ind] + ps[ind+1]) / 2 - x_cons for ind in tail]

    h = ps[-1]
    l = 0
    d2 = (h + l) / 2
    alphas.extend(d2*off + cons for off in offsets)
    num_its = 0
    while abs(h - l) > error:
        num_its += 1
        bracket_before = (l, h)
        d2 = (h + l) / 2
        for ind, off in zip(tail, offsets):
            alphas[ind] = d2*off + cons
        _, ws, _ = get_phi_values(ps[1:], alphas)
        if ws[-1] <= 1:
            h = d2
        else:
            l = d2
        if (l, h) == bracket_before:
            # The midpoint rounded onto an endpoint, so every further iteration
            # would repeat this one exactly: stop instead of looping forever.
            break
    return d2, num_its, l, h

def generate_price_intervals(M, num_intervals):
    """Return the right endpoints of `num_intervals` evenly spaced intervals on [1, M].

    The grid points are truncated toward zero with astype(int), and the leading
    grid point (1) is dropped. The result is a list of numpy integers.
    """
    grid = np.linspace(1, M, num_intervals + 1).astype(int)
    endpoints = grid[1:]
    return list(endpoints)

def generate_lists_alphas(ps, ind_alpha, alpha, profile_type=1):
    """Return one growth rate per price endpoint in `ps`.

    profile_type 1 (linear): entry i is alpha * (1 + |i - ind_alpha|), so
    the rate is alpha at ind_alpha and grows linearly away from it.
    profile_type 2 (quadratic) is listed but not implemented. Any value other
    than 1 raises NotImplementedError.
    """
    if profile_type != 1:
        raise NotImplementedError
    return [alpha * (1 + abs(k - ind_alpha)) for k in range(len(ps))]
    
def generate_lists_alphas_2(ps, ind_alpha, cons, d1):
    """Return growth rates for the first `ind_alpha` intervals [ps[k], ps[k+1]].

    Each rate is linear in the interval midpoint:
    d1 * (midpoint - x_cons) + cons, where x_cons is the midpoint of
    [ps[ind_alpha-1], ps[ind_alpha]] (the last generated interval gets `cons`).
    """
    x_cons = (ps[ind_alpha-1] + ps[ind_alpha]) / 2
    return [d1 * ((ps[k] + ps[k+1]) / 2 - x_cons) + cons for k in range(ind_alpha)]
    
def run(num_intervals, ind, alpha, M=100):
    """Run do_binary_search on an evenly spaced integer price grid over [1, M].

    Args:
        num_intervals: number of price intervals (see generate_price_intervals).
        ind: index whose rate is `alpha`. The rates grow linearly away from it
            (generate_lists_alphas with its default linear profile).
        alpha: base growth rate.
        M: maximum price (defaults to 100).

    Returns:
        ((rs, ws, ss, d), als): do_binary_search's result and the rates used.
    """
    ps = generate_price_intervals(M, num_intervals)
    als = generate_lists_alphas(ps, ind, alpha)
    return do_binary_search(ps, als, alpha), als

def run_d2(num_intervals, ind, alpha, d1, M=100):
    """Run get_delta_2 with the rates left of the consistency interval set by slope d1.

    Args:
        num_intervals: number of price intervals (see generate_price_intervals).
        ind: consistency index, passed to get_delta_2 as ind_cons. The rates for
            intervals [ps[k], ps[k+1]], k < ind, come from generate_lists_alphas_2.
        alpha: rate at the consistency interval (`cons`).
        d1: slope of the rates to the left of the consistency interval.
        M: maximum price (defaults to 100).

    Returns:
        get_delta_2's (d2, num_its, l, h).

    NOTE(review): generate_price_intervals drops the leading 1, but get_delta_2
    passes ps[1:] to get_phi_values, whose first interval starts at price 1.
    The rate generated for [ps[0], ps[1]] is therefore applied on [1, ps[1]].
    Confirm whether ps should be prefixed with 1 here.
    """
    ps = generate_price_intervals(M, num_intervals)
    als = generate_lists_alphas_2(ps, ind, alpha, d1)
    return get_delta_2(ps, als, ind)

# ps = [10, 15, 20, 35, 55, 75, 100]
# alphas = [4, 3.5, 3, 3.15, 3.25, 3.35, 3.5]
# rs, ws, ss = get_phi_values(ps, alphas)
# print(rs)
# print(ws)
# print(ss)
# plot(rs, ws, alphas)

# alpha = 3.8
# rs, ws, _, d = do_binary_search([50, 100], [alpha, alpha], alpha, one_sided=True)
# print(rs, ws, d)
# plot(rs, ws, [d*alpha, alpha, d*alpha])

# (rs, ws, ss, d), als = run(10, 5, 6)
# print(rs,'\n', ws,'\n', d,'\n', als)
# plot(rs, ws, als)

# for d1 in np.arange(-5,0,0.5):
#     d2, n, l, h = run_d2(20, 3, 3.5, d1)
#     print(d1, d2, n, l, h)       