import numpy as np
import matplotlib.pyplot as plt

def inlier_distribution(sample_size, alpha, epsilon, k, setting):
    """Draw i.i.d. samples from a symmetric three-point inlier distribution.

    The support is {+1/gamma, -1/gamma, 0}. Each nonzero atom carries mass
    gamma**k / 2 and zero carries the remaining 1 - gamma**k, so the
    distribution has mean 0 and E[|X|**k] = 1.

    Args:
        sample_size: number of samples to draw.
        alpha: corruption level.
        epsilon: privacy parameter (only used when ``setting == 'LTC'``).
        k: moment order.
        setting: 'LTC' uses gamma = (alpha / epsilon) ** (1 / k);
            'CTL' uses gamma = alpha ** (1 / k).

    Returns:
        NumPy array of ``sample_size`` draws.

    Raises:
        ValueError: if ``setting`` is neither 'LTC' nor 'CTL', or (raised by
            ``np.random.choice``) if gamma**k > 1 makes a probability negative.
    """
    if setting == 'LTC':
        base = alpha / epsilon
    elif setting == 'CTL':
        base = alpha
    else:
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")
    gamma = base ** (1 / k)

    tail_mass = gamma ** k
    # Atom order must stay (+, -, 0) so seeded draws are reproducible.
    atoms = [1 / gamma, -1 / gamma, 0]
    weights = [0.5 * tail_mass, 0.5 * tail_mass, 1 - tail_mass]
    return np.random.choice(atoms, size=sample_size, p=weights)

def compute_M(epsilon, alpha, k, n, delta, setting):
    """Return the truncation level M for the LDP mechanism and analyzer.

    M is the smaller of a setting-specific corruption cap and the
    sample-size term (epsilon * sqrt(n) / sqrt(log(1 / delta))) ** (1 / k):

    - 'LTC': cap = (epsilon / alpha) ** (1 / k) times 1/3
    - 'CTL': cap = (1 / alpha) ** (1 / k)

    Args:
        epsilon: privacy parameter.
        alpha: corruption level.
        k: moment order.
        n: number of samples.
        delta: failure probability (enters through log(1 / delta)).
        setting: 'LTC' or 'CTL'.

    Returns:
        The scalar truncation level M.

    Raises:
        ValueError: if ``setting`` is neither 'LTC' nor 'CTL'.
    """
    # Validate (and compute the cap) before touching delta, as before.
    if setting == 'LTC':
        cap = (epsilon / alpha) ** (1 / k) * (1/3)
    elif setting == 'CTL':
        cap = (1 / alpha) ** (1 / k)
    else:
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")

    sample_term = (epsilon * np.sqrt(n) / np.sqrt(np.log(1 / delta))) ** (1 / k)
    return min(cap, sample_term)

def epsilon_ldp_mechanism(U, M, epsilon):
    """Privatize every entry of ``U`` with a two-stage epsilon-LDP randomizer.

    Stage 1 (truncate and round): entries with |u| > M become 0. Each
    remaining value u is rounded to +M with probability (1 + u / M) / 2 and
    to -M otherwise, so its expectation is u.

    Stage 2 (randomized response): the sign is kept with probability
    e^eps / (e^eps + 1), otherwise flipped. The result is then scaled by
    (e^eps + 1) / (e^eps - 1), which makes it unbiased for the truncated
    value.

    Every output is therefore +/- M * (e^eps + 1) / (e^eps - 1).

    Args:
        U: array-like of values (must support ``len``).
        M: truncation level.
        epsilon: privacy parameter.

    Returns:
        NumPy array with ``len(U)`` privatized reports.
    """
    n_users = len(U)

    # Stage 1: zero out-of-range values, then stochastically round to +/-M.
    truncated = np.where(np.abs(U) > M, 0, U)
    up_prob = (1 + truncated / M) / 2
    rounded = np.where(np.random.rand(n_users) < up_prob, M, -M)

    # Stage 2: debiased randomized response on the sign.
    exp_eps = np.exp(epsilon)
    truth_prob = exp_eps / (exp_eps + 1)
    debias = (exp_eps + 1) / (exp_eps - 1)
    honest = np.random.rand(n_users) < truth_prob
    return np.where(honest, debias * rounded, -debias * rounded)

def analyzer(Z, M, epsilon):
    """Average privatized reports after discarding out-of-range ones.

    The bound M * (e^eps + 1) / (e^eps - 1) is the magnitude every report of
    ``epsilon_ldp_mechanism`` has for this M. Reports whose absolute value
    exceeds it are replaced by 0 (still counted in the denominator) before
    taking the mean.

    Args:
        Z: array-like of reports (any shape; the mean is over all entries).
        M: truncation level.
        epsilon: privacy parameter.

    Returns:
        The mean of the clipped reports.
    """
    exp_eps = np.exp(epsilon)
    bound = M * ((exp_eps + 1) / (exp_eps - 1))
    out_of_range = np.abs(Z) > bound
    return np.mean(np.where(out_of_range, 0, Z))

def huber_corruption(Input, M, alpha, epsilon, setting, corru_setting):
    """Corrupt each entry of ``Input`` independently with probability ``alpha``.

    Selected entries are replaced as follows; all other entries are kept:

    - 'Weak' (either setting): value -> -value
    - 'LTC' + 'Strong': value -> abs(value)
    - 'CTL' + 'Strong': value -> M

    Args:
        Input: array-like of values (must support ``len``).
        M: replacement value for 'CTL' + 'Strong'.
        alpha: per-entry corruption probability.
        epsilon: accepted but not used by this function.
        setting: 'LTC' or 'CTL'.
        corru_setting: 'Strong' or 'Weak'.

    Returns:
        NumPy array with the same length as ``Input``.

    Raises:
        ValueError: for an unknown ``setting`` (checked first) or an unknown
            ``corru_setting``.
    """
    if setting not in ('LTC', 'CTL'):
        raise ValueError("Invalid setting. Choose 'LTC' or 'CTL'.")
    if corru_setting not in ('Strong', 'Weak'):
        raise ValueError("Invalid setting. Choose 'Strong' or 'Weak'.")

    # Draw the corruption mask before building the replacement values.
    corrupted = np.random.rand(len(Input)) < alpha
    if corru_setting == 'Weak':
        replacement = -Input
    elif setting == 'LTC':
        replacement = np.abs(Input)
    else:
        replacement = M
    return np.where(corrupted, replacement, Input)

def Whole_Procedure(alpha, epsilon, k, setting, corru_setting,
                    sample_size=int(1e4), delta=0.05):
    """Simulate streaming robust private mean estimation and record its error.

    For each step i = 1..sample_size, one inlier is drawn, the truncation
    level M is recomputed for n = i, and the sample is corrupted and
    privatized in the order given by ``setting``:

    - 'CTL': Huber corruption is applied to the raw sample, which is then
      privatized.
    - 'LTC': the sample is privatized first, then the private report is
      corrupted.

    After each step the analyzer runs on all reports collected so far. The
    absolute value of its estimate is recorded; the inlier distribution has
    mean 0, so this is the absolute estimation error.

    Args:
        alpha: corruption level.
        epsilon: privacy parameter.
        k: moment order, used for both the inlier distribution and M.
        setting: 'LTC' or 'CTL'.
        corru_setting: 'Strong' or 'Weak'.
        sample_size: number of steps/reports (default 10_000).
        delta: failure probability passed to ``compute_M`` (default 0.05).

    Returns:
        List of length ``sample_size``; entry i-1 is the absolute estimate
        after i reports.
    """
    # Preallocated buffer: slicing it is a view. The previous version
    # re-converted a growing list of 1-element arrays on every step
    # (O(n^2) Python-level work).
    reports = np.empty(sample_size)
    mean_error = []
    for i in range(1, sample_size + 1):
        x_i = inlier_distribution(sample_size=1, alpha=alpha, epsilon=epsilon,
                                  k=k, setting=setting)
        # M must use the same moment order k as the inlier distribution.
        M = compute_M(epsilon=epsilon, alpha=alpha, k=k, n=i, delta=delta,
                      setting=setting)

        if setting == 'CTL':
            x_i = huber_corruption(x_i, M, alpha, epsilon, setting='CTL',
                                   corru_setting=corru_setting)

        y_i = epsilon_ldp_mechanism(x_i, M, epsilon)

        if setting == 'LTC':
            y_i = huber_corruption(y_i, M, alpha, epsilon, setting='LTC',
                                   corru_setting=corru_setting)

        # NOTE(review): earlier reports were privatized with the (smaller or
        # equal) M of their own step, while the analyzer clips with the
        # current M. Confirm this streaming approximation is intended.
        reports[i - 1] = y_i[0]
        new_mean = analyzer(reports[:i], M, epsilon)
        mean_error.append(np.abs(new_mean))  # E[X] = 0 for the inlier distribution
    return mean_error

# Example usage: one 'CTL' simulation with 'Strong' Huber corruption.
alpha = 0.02          # corruption probability
epsilon = 0.2         # privacy parameter
k = 2                 # moment order of the inlier distribution
sample_size = 10_000  # informational only; not forwarded to Whole_Procedure
mean_error = Whole_Procedure(alpha, epsilon, k, 'CTL', 'Strong')