import numpy as np
import math

def get_height(T):
    ''' Returns the tree height the standard binary mechanism needs to support 'T' inputs.

    The variant considered here combines only left subtrees and never uses the root.
    The height is the smallest h with T <= 2^h - 1, i.e. ceil(log2(T + 1)).

    Parameters:
    T: Number of inputs the tree must support.

    Returns:
    The height of the tree.
    '''
    count_plus_one = T + 1
    return math.ceil(math.log2(count_plus_one))

def compute_total_number_used_for_releasing_m_outputs(m):
    ''' Counts how many tree nodes are summed in total when releasing the
    first 'm' prefix sums with the standard binary mechanism.

    The k-th prefix sum adds up one node per set bit in the binary representation
    of k, so the total is the number of set bits over 1, ..., m.
    This is a linear-time loop; it is simple, correct, and not a bottleneck.

    Parameters:
    m: Number of outputs (a non-positive 'm' yields 0).

    Returns:
    Total number of nodes used in releasing the outputs.
    '''
    return sum(bin(k).count('1') for k in range(1, m + 1))

def compute_epsilons_from_mean_squared_error(target_mse, privacy_ratio, w, T):
    ''' Computes 'epsilon_per_node' and 'epsilon_for_past' that achieve a target mean-squared error.

    Every tree node gets noise z~Lap(1/epsilon_per_node), with variance 2/epsilon_per_node^2.
    Every output after the first round also includes the noisy sum of all past rounds,
    z~Lap(1/epsilon_for_past), with variance 2/epsilon_for_past^2. The epsilons are chosen so that
    the total variance over all 'T' outputs, divided by 'T', equals 'target_mse'.

    Parameters:
    target_mse: The target mean-squared error we want to achieve over all 'T' inputs.
    privacy_ratio: The ratio between the privacy spent on releasing the sum of all inputs from past
        rounds and the privacy spent on the tree in each round. It must be passed explicitly (no
        default is applied here) and is expected to be positive.
    w: Length of each round.
    T: Length of the stream.

    The privacy ratio enforces that
    'epsilon_tree * privacy_ratio = tree_height * epsilon_per_node * privacy_ratio = epsilon_for_past'

    Returns:
    The tuple '(epsilon_per_node, epsilon_for_past)'.
    '''

    # Compute how many nodes are added in total: full rounds of length 'w' plus a possibly partial last round
    num_full_rounds = math.floor(T / w)
    total_num_nodes_used = compute_total_number_used_for_releasing_m_outputs(w) * num_full_rounds
    if T % w != 0:
        last_round_length = T % w
        total_num_nodes_used += compute_total_number_used_for_releasing_m_outputs(last_round_length)

    # Only outputs after the first round include the noisy sum of past rounds.
    # Clamp at 0 so a stream shorter than one round does not get a negative count.
    num_outputs_using_past = max(T - w, 0)

    target_squared_l2_error = target_mse * T

    # Explicitly compute the epsilons necessary to achieve the given squared l2-error
    h = get_height(w)
    epsilon_per_node = math.sqrt(2 *
        (total_num_nodes_used + num_outputs_using_past * math.pow(privacy_ratio * h, -2)) / target_squared_l2_error)
    epsilon_for_past = privacy_ratio * h * epsilon_per_node

    # Unnecessary, but a sanity check. Use a relative tolerance because 'target_mse' can have any
    # magnitude; for large values a fixed absolute tolerance of 1e-9 is smaller than one ulp.
    achieved_mse = 2 * (math.pow(epsilon_per_node, -2) * total_num_nodes_used
                        + math.pow(epsilon_for_past, -2) * num_outputs_using_past) / T
    assert math.isclose(achieved_mse, target_mse, rel_tol=1e-9, abs_tol=1e-9), "Math wrong.."

    print(f'We use epsilon_past={epsilon_for_past}, epsilon_tree={epsilon_per_node * h} and for w={w}, T={T} to achieve MSE={target_mse}')

    return epsilon_per_node, epsilon_for_past

def compute_optimal_privacy_ratio_for_l2_error(w, d, T):
    ''' Computes the optimal split of the privacy budget between what is spent on releasing the tree
    in each round, and what is spent on releasing past inputs.

    Maximizing the privacy at 'd' under the constraint of achieving a
    fixed l2-error yields the privacy_ratio.

    NOTE: Assumes that T > d > w.

    Under this assumption, the privacy loss after 'd' is:

        'tree_height * epsilon_per_node + future_rounds * epsilon_for_past'

    NOTE: This is equivalent to saying 'epsilon_tree * privacy_ratio = epsilon_for_past', since for the
        conventional binary mechanism we add noise Lap(1/epsilon_per_node) = Lap(h/epsilon_tree).

    Parameters:
    w: Length of each round.
    d: The value of 'd' at which we want the privacy loss minimized.
    T: Length of the stream (up to which we normalize).

    Returns:
    'privacy_ratio' such that enforcing h * epsilon_per_node * privacy_ratio = epsilon_for_past
        gives the best possible privacy for an input 'd' steps back from a release.
    '''
    assert T > d, "Need T > d"
    assert d > w, "Only meaningful to compute if more than one round is involved"

    h = get_height(w)
    N = math.ceil(d / w) - 1 # This computation might change depending on what happens in the first round.
    # NOTE(review): compute_epsilons_from_mean_squared_error counts the nodes summed per round of length 'w',
    # whereas 'M' here counts the nodes for releasing all 'T' outputs from one tree. Confirm which is intended.
    M = compute_total_number_used_for_releasing_m_outputs(T)
    alpha = (T - w) / M

    # The optimal privacy ratio is the real, positive root of this polynomial.
    # Given the asserts above, N >= 1, h >= 1 and alpha > 0, so the coefficient signs are (+, +, 0, -, -).
    # There is a single sign change, so by Descartes' rule of signs exactly one root is positive and real.
    # np.roots does not guarantee any ordering of its output, so pick that root explicitly
    # instead of assuming it is the first entry.
    roots = np.roots([math.pow(N*h, 2), N*h, 0, -alpha * N, -alpha])
    positive_roots = [root for root in roots if np.real(root) > 0]
    assert positive_roots, "Expected a root with positive real part"
    privacy_ratio = min(positive_roots, key=lambda root: abs(np.imag(root)))
    assert abs(np.imag(privacy_ratio)) <= 1e-9 * abs(privacy_ratio)
    assert np.real(privacy_ratio) > 0

    privacy_ratio = np.real(privacy_ratio)
    # Print here what privacy ratio we use
    print(f"For W={w}, T={T} and d={d} to optimize at, we get privacy_ratio=epsilon_past/epsilon_tree={privacy_ratio}")

    return privacy_ratio

def privacy_loss_for_t_at_tau(t, tau, epsilon_per_node):
    ''' Privacy loss of the input x_t accumulated over all releases up to time 'tau',
    for the standard binary mechanism that only uses left subtrees and never the root.

    Parameters:
    t: Time step of the input x_t whose privacy we are computing (1 <= t <= tau).
    tau: Time step at which the privacy is evaluated.
    epsilon_per_node: Each tree node is noised with z~Lap(1/epsilon_per_node).

    Returns:
    'epsilon_per_node' times the number of nodes holding x_t that are used by some output up to 'tau'.
    '''
    assert t > 0
    assert t <= tau

    width = get_height(tau)

    # Root-to-leaf bit path of the leaf storing x_t (leaves are 0-indexed, hence t-1)
    leaf_bits = np.binary_repr(t - 1, width=width)
    # Root-to-leaf bit path of the leaf that would store x_{tau+1}; it determines which
    # nodes are combined for the prefix sum up to 'tau'
    last_bits = np.binary_repr(tau, width=width)

    # Locate the first level at which the two paths part ways
    split = 0
    while leaf_bits[split] == last_bits[split]:
        split += 1

    # Each '0' on x_t's path from the split downward is a left-child ancestor whose node gets
    # used in some output up to 'tau'. Ancestors above the split are shared with the path to
    # x_{tau+1}, so they are not counted.
    left_children_used = leaf_bits[split:].count('0')
    return epsilon_per_node * left_children_used

def privacy_loss_for_t_at_tau_with_rounds(t, tau, epsilon_per_node, epsilon_for_past, w):
    ''' Privacy loss of the input x_t over all releases up to time 'tau', when the
    standard binary mechanism is reset every 'w' steps.

    The stream is processed in rounds of length 'w'. Each round runs its own binary tree,
    and every round after the first starts by releasing the sum of all inputs from past
    rounds with noise z~Lap(1/epsilon_for_past).

    Parameters:
    t: Time step of the input x_t whose privacy we are computing (1 <= t <= tau).
    tau: Time step at which the privacy is evaluated.
    epsilon_per_node: Each tree node is noised with z~Lap(1/epsilon_per_node).
    epsilon_for_past: The sum of all inputs from past rounds is noised with z~Lap(1/epsilon_for_past).
    w: Length of each round.

    Returns:
    The privacy loss for the input arriving at time 't', evaluated at time 'tau'.
    '''
    assert t > 0
    assert t <= tau

    # With tau <= w we never leave the first round, so no reset has happened yet
    if tau <= w:
        return privacy_loss_for_t_at_tau(t, tau, epsilon_per_node)

    # Here tau > w, so more than one round is involved
    last_round = math.ceil(tau / w)
    round_of_t = math.ceil(t / w)
    position_of_t = (t - 1) % w + 1

    if last_round == round_of_t:
        # x_t's round is still running at 'tau': only part of its tree has been released
        end_in_round = (tau - 1) % w + 1
    else:
        # x_t's round has finished: all 'w' outputs of its tree were released
        end_in_round = w
    tree_loss = privacy_loss_for_t_at_tau(position_of_t, end_in_round, epsilon_per_node)

    # Every later round releases the noisy sum of past rounds, which contains x_t
    past_loss = (last_round - round_of_t) * epsilon_for_past
    return tree_loss + past_loss

def worst_case_privacy_loss_after_d_steps(d, epsilon_per_node, epsilon_for_past, w, T):
    ''' Worst-case privacy loss of an input at time 't', counting all releases up to
    and including time 'tau = t + d', maximized over the valid 't'.

    The guarantee for a given 't' depends on its relative position within its round.
    There are at most 'w' such positions, and they are all enumerated in the first round.
    Only positions with 't + d <= T' are considered.

    Parameters:
    d: Number of steps after the input at which its privacy is evaluated.
    epsilon_per_node: The epsilon used for the noise of each tree node, z~Lap(1/epsilon_per_node).
    epsilon_for_past: The epsilon used for the noise on the sum of past rounds, z~Lap(1/epsilon_for_past).
    w: Length of each round.
    T: Length of the stream.

    Returns:
    The worst-case privacy loss for an input, evaluated at the end of
        the time step 'd' steps into the future.
    '''
    last_start = min(w, T - d)
    losses = []
    for t in range(1, last_start + 1):
        tau = min(t + d, T)
        losses.append(privacy_loss_for_t_at_tau_with_rounds(t, tau, epsilon_per_node, epsilon_for_past, w))
    return max(losses)

def empirical_privacy_loss(d_vec, T, target_mse, w, privacy_ratio=None):
    ''' Computes the privacy loss for this baseline for every 'd' in 'd_vec'.

    Parameters:
    d_vec: Non-empty sequence of values 'd' for which to compute the privacy expiration; every 'd' must be < 'T'.
    T: Length of the stream.
    target_mse: The target mean-squared error we want to achieve over all 'T' inputs.
    w: Length of each round.
    privacy_ratio: The ratio between the privacy spent on releasing the sum of all inputs from past rounds
        and the privacy spent on the tree in each round. It defaults to 'None'. In that case the ratio is
        chosen so that the privacy loss at the largest 'd' in 'd_vec' is minimal (i.e. the privacy is
        maximal). That optimization requires 'max(d_vec) > w'.

    Returns:
    A numpy array 'y' where 'y[i]' is the running maximum of the worst-case privacy loss
        over 'd_vec[0], ..., d_vec[i]', taken in the given order.
    '''

    largest_d = max(d_vec)
    assert largest_d < T, "Need d < T"

    # Compute the best choice of privacy parameters, in the sense of minimizing the privacy loss at the
    # largest 'd', unless a ratio is given explicitly. The callee already prints the chosen ratio.
    if privacy_ratio is None:
        privacy_ratio = compute_optimal_privacy_ratio_for_l2_error(w, largest_d, T)

    # Compute the privacy parameters to use based on a target average variance (expected L2-squared error)
    epsilon_per_node, epsilon_for_past = compute_epsilons_from_mean_squared_error(
        target_mse=target_mse, privacy_ratio=privacy_ratio, w=w, T=T)

    # Compute the worst-case privacy 'd' steps back from a release
    privacy_d_steps_back = np.array(
        [worst_case_privacy_loss_after_d_steps(d, epsilon_per_node=epsilon_per_node, epsilon_for_past=epsilon_for_past, w=w, T=T) for d in d_vec])

    # The running-max over the worst-case privacy gives the worst-case privacy loss
    return np.maximum.accumulate(privacy_d_steps_back)