#!/usr/bin/env python3
# -*- coding: utf-8 -*-

##Implementation of fixed-size subsampling (with and without replacement) Renyi Differential Privacy (FS-RDP) method from the paper
##Renyi DP-SGD with Fixed-Size Minibatches: Tighter Guarantees with or without Replacement

import math
import numpy as np
from scipy.special import comb as comb






#Definition of functions involved in the bounds from Thm 3.3
def log_k_choose_n(k, n):
    """Return the natural log of the binomial coefficient ``k choose n``.

    The result is log(k) + log(k-1) + ... + log(k-n+1) minus
    log(1) + ... + log(n), so the coefficient itself is never formed.
    When ``n`` is 0 no term is added and the integer 0 is returned.
    """
    # Numerator factors first, then the negated denominator factors, so the
    # running total sees the terms in numerator-then-denominator order.
    signed_logs = [np.log(k - j) for j in range(n)]
    signed_logs.extend(-np.log(j) for j in range(1, n + 1))

    total = 0
    for term in signed_logs:
        total = total + term
    return total

def log_factorial(k):
    """Return log(k!) as the running sum log(2) + log(3) + ... + log(k).

    For ``k <= 1`` the sum is empty and the integer 0 is returned.
    """
    total = 0
    for factor in range(2, k + 1):
        total += np.log(factor)
    return total

def M_exact(sigma, k):
    """Evaluate the alternating sum M(sigma, k) directly (not in log space).

    Computes
        (-1)^(k-1) (k-1)
        + sum_{ell=2}^{k} (-1)^(k-ell) C(k,ell) exp(ell (ell-1) (2/sigma)^2 / 2).
    The binomial coefficient enters through its logarithm, but each summand
    is still exponentiated, so large k or small sigma can overflow.
    """
    total = (-1)**(k-1)*(k-1)
    for ell in range(2, k+1):
        sign = (-1)**(k-ell)
        log_term = log_k_choose_n(k, ell)+ell*(ell-1)*(2./sigma)**2/2
        total = total+sign*np.exp(log_term)
    return total



def log_M_exact(sigma, k):
    """Return log(M(sigma, k)) for the same alternating sum as M_exact.

    Every summand is rescaled by exp(-shift), where shift is the largest
    exponent among the ell = 2..k terms, and shift is added back after the
    logarithm is taken. Requires k >= 2 (np.max fails on an empty list).
    """
    ells = range(2, k+1)
    exponents = [log_k_choose_n(k, ell)+ell*(ell-1)*(2./sigma)**2/2 for ell in ells]
    shift = np.max(exponents)

    # Leading (-1)^(k-1) (k-1) term, rescaled like the others.
    scaled_sum = (-1)**(k-1)*np.exp(np.log(k-1.)-shift)
    for ell, exponent in zip(ells, exponents):
        scaled_sum = scaled_sum+(-1)**(k-ell)*np.exp(exponent-shift)
    return np.log(scaled_sum)+shift


def B_bound(sigma, m):
    """Return M_exact(sigma, m) for even m; for odd m return the geometric
    mean of M_exact at the two neighbouring even orders m-1 and m+1.
    """
    if m % 2 != 0:
        below = M_exact(sigma, m-1)
        above = M_exact(sigma, m+1)
        return below**(1/2)*above**(1/2)
    return M_exact(sigma, m)
    
def log_B_bound(sigma, m):
    """Log-space version of B_bound: log_M_exact(sigma, m) for even m, and
    the average of log_M_exact at m-1 and m+1 for odd m.
    """
    if m % 2 != 0:
        below = log_M_exact(sigma, m-1)
        above = log_M_exact(sigma, m+1)
        return 0.5*(below+above)
    return log_M_exact(sigma, m)
    

def A_bound(alpha, sigma, m):
    """Evaluate the A term directly (not in log space).

    For even m:
        sum_{ell=0}^{m} (-1)^(m-ell) C(m,ell)
            exp((alpha+ell-m-1)(alpha+ell-m)(2/sigma)^2 / 2).
    For odd m: the geometric mean of the values at m-1 and m+1.
    """
    if m % 2 != 0:
        below = A_bound(alpha, sigma, m-1)
        above = A_bound(alpha, sigma, m+1)
        return below**(1/2)*above**(1/2)

    total = 0
    for ell in range(m+1):
        shifted = alpha+ell-m
        exponent = log_k_choose_n(m, ell)+(shifted-1)*shifted*(2./sigma)**2/2
        total = total+(-1)**(m-ell)*np.exp(exponent)
    return total
    
def log_A_bound(alpha, sigma, m):
    """Log-space version of A_bound.

    For even m the alternating sum is rescaled by exp(-shift), with shift the
    largest exponent, before taking the logarithm. For odd m the result is the
    average of the log values at m-1 and m+1.
    """
    if m % 2 != 0:
        below = log_A_bound(alpha, sigma, m-1)
        above = log_A_bound(alpha, sigma, m+1)
        return 0.5*(below+above)

    exponents = []
    for ell in range(m+1):
        shifted = alpha+ell-m
        exponents.append(log_k_choose_n(m, ell)+(shifted-1)*shifted*(2./sigma)**2/2)
    shift = np.max(exponents)

    scaled_sum = 0
    for ell, exponent in enumerate(exponents):
        scaled_sum = scaled_sum+(-1)**(m-ell)*np.exp(exponent-shift)
    return np.log(scaled_sum)+shift
        

def R_bound(alpha, sigma, m, q):
    """Evaluate the order-m remainder term directly (not in log space).

    Returns
        q^m / m! * alpha * prod_{j=1}^{m-1} |alpha - j| * W,
    where W = A_bound + B_bound when 0 < alpha - m < 1, and otherwise
    W = q/(m+1) * A_bound + (1 - q/(m+1)) * B_bound.

    ``m`` must be an integer accepted by ``math.factorial``.
    """
    abs_alpha_prod = alpha
    for j in range(1, m):
        abs_alpha_prod = abs_alpha_prod*np.abs(alpha-j)

    # math.factorial replaces the np.math alias, which newer NumPy no longer provides.
    prefactor = q**m/math.factorial(m)*abs_alpha_prod
    a_term = A_bound(alpha, sigma, m)
    b_term = B_bound(sigma, m)

    if 0 < alpha-m < 1:
        return prefactor*(a_term+b_term)

    weight = q/(m+1)
    return prefactor*(weight*a_term+(1-weight)*b_term)

def log_R_bound(alpha, sigma, m, q):
    """Log-space version of R_bound.

    Returns
        m*log(q) - log(m!) + log(alpha * prod_{j=1}^{m-1} |alpha - j|) + log(W),
    where W = exp(log_A) + exp(log_B) when 0 < alpha - m < 1, and otherwise
    W = q/(m+1) * exp(log_A) + (1 - q/(m+1)) * exp(log_B). W is evaluated
    after subtracting max(log_A, log_B) from both exponents.

    If alpha is an integer in [1, m-1] the product is zero and np.log of it
    is -inf.
    """
    abs_alpha_prod = alpha
    for j in range(1, m):
        abs_alpha_prod = abs_alpha_prod*np.abs(alpha-j)

    # Both log terms are needed by either branch; compute them exactly once
    # (log_A_bound / log_B_bound are recursive for odd m).
    exponent_A = log_A_bound(alpha, sigma, m)
    exponent_B = log_B_bound(sigma, m)
    exp_max = np.maximum(exponent_A, exponent_B)

    log_prefactor = np.log(q)*m-log_factorial(m)+np.log(abs_alpha_prod)

    if 0 < alpha-m < 1:
        exp_min = np.minimum(exponent_A, exponent_B)
        log_mix = np.log(1.+np.exp(exp_min-exp_max))
    else:
        weight = q/(m+1)
        log_mix = np.log(weight*np.exp(exponent_A-exp_max)+(1-weight)*np.exp(exponent_B-exp_max))

    return log_prefactor+log_mix+exp_max
                    

def H_bound(alpha, sigma, m, q):
    """Evaluate the order-m expansion bound directly (not in log space).

    Returns
        1 + sum_{k=2}^{m-1} q^k / k! * alpha (alpha-1) ... (alpha-k+1) * M_exact(sigma, k)
          + R_bound(alpha, sigma, m, q).
    """
    H_terms = [1.]
    alpha_prod = alpha
    for k in range(2, m):
        # Running falling factorial alpha (alpha-1) ... (alpha-k+1).
        alpha_prod = alpha_prod*(alpha-k+1)
        # math.factorial replaces the np.math alias, which newer NumPy no longer provides.
        H_terms.append(q**k/math.factorial(k)*alpha_prod*M_exact(sigma, k))
    H_terms.append(R_bound(alpha, sigma, m, q))

    return np.sum(H_terms)

def H_bound_minus_1(alpha, sigma, m, q):
    """Evaluate the H_bound expansion without its leading 1.

    Returns
        sum_{k=2}^{m-1} q^k / k! * alpha (alpha-1) ... (alpha-k+1) * M_exact(sigma, k)
          + R_bound(alpha, sigma, m, q).
    """
    H_terms = []
    alpha_prod = alpha
    for k in range(2, m):
        # Running falling factorial alpha (alpha-1) ... (alpha-k+1).
        alpha_prod = alpha_prod*(alpha-k+1)
        # math.factorial replaces the np.math alias, which newer NumPy no longer provides.
        H_terms.append(q**k/math.factorial(k)*alpha_prod*M_exact(sigma, k))
    H_terms.append(R_bound(alpha, sigma, m, q))

    return np.sum(H_terms)

def log_H_bound_minus_1(alpha, sigma, m, q):
    """Log-space version of H_bound_minus_1.

    Each term is stored as a sign and a log-magnitude; the signed sum is
    formed after subtracting the largest log-magnitude. Terms whose falling
    factorial alpha (alpha-1) ... (alpha-k+1) is exactly zero are dropped, and
    the log_R_bound term is dropped when alpha is an integer below m.
    """
    log_magnitudes = []
    term_signs = []
    falling = alpha
    for k in range(2, m):
        falling = falling*(alpha-k+1)
        if falling == 0:
            continue
        log_magnitudes.append(
            k*np.log(q)-log_factorial(k)+np.log(np.abs(falling))+log_M_exact(sigma, k)
        )
        term_signs.append(np.sign(falling))

    integer_alpha_below_m = alpha < m and alpha == int(alpha)
    if not integer_alpha_below_m:
        log_magnitudes.append(log_R_bound(alpha, sigma, m, q))
        term_signs.append(1.)

    shift = np.max(log_magnitudes)
    signed_scaled = np.asarray(term_signs)*np.exp(np.asarray(log_magnitudes)-shift)
    return np.log(np.sum(signed_scaled))+shift

##One step RDP bound for fixed-size subsampling without replacement (Thm 3.3)
def FS_RDP_woR(alpha, sigma, m, q):
    """Return 1/(alpha-1) * log(H_bound(alpha, sigma, order, q)).

    ``order`` is m, except when 0 < alpha - m < 1, where m+1 is used instead.
    """
    #the bound performs poorly when alpha-m\in(0,1), so in that case m is bumped by 1
    order = m+1 if 0 < alpha-m < 1 else m
    return 1/(alpha-1)*np.log(H_bound(alpha, sigma, order, q))
    
#Thm 3.1 bounds for integer alpha (Taylor remainder is zero)
def K_integer_alpha(alpha, sigma, q):
    """Return log of sum_{k=0}^{alpha} C(alpha,k) (1-q)^(alpha-k) q^k
    exp(k (k-1) (2/sigma)^2 / 2).

    ``alpha`` must be an integer (it is used as a ``range`` bound).
    """
    K_terms = [
        comb(alpha, k, exact=True)*(1-q)**(alpha-k)*q**k*np.exp(k*(k-1)*(2./sigma)**2/2)
        for k in range(alpha+1)
    ]

    #Accumulate from the last term to the first for numerical stability
    total = 0
    for term in reversed(K_terms):
        total = total+term
    return np.log(total)

# bounds obtained by combining formula for integer alpha with convexity technique of Wang et al
def FS_RDP_woR_convexity_method(alpha, sigma, q):
    """Return an RDP bound built from K_integer_alpha.

    * alpha below 2: the value at alpha = 2 is returned.
    * integer alpha: K_integer_alpha(alpha) / (alpha - 1).
    * otherwise: K_integer_alpha at floor(alpha) and floor(alpha)+1 are
      combined with weights (1 - frac) and frac, frac = alpha - floor(alpha),
      and the result is divided by (alpha - 1).
    """
    if not alpha >= 2:
        return FS_RDP_woR_convexity_method(2, sigma, q)

    if int(alpha) == alpha:
        return 1./(alpha-1.)*K_integer_alpha(int(alpha), sigma, q)

    lower = math.floor(alpha)
    frac = alpha-lower
    lower_part = (1.-frac)/(alpha-1)*K_integer_alpha(int(lower), sigma, q)
    upper_part = frac/(alpha-1)*K_integer_alpha(int(lower+1), sigma, q)
    return lower_part+upper_part
    
    




##One step RDP bound for fixed-size subsampling with replacement (Theorem 3.4)
#K is the number of terms that are not bounded using the worst-case result
def FS_RDP_wR(alpha, sigma, m, B, D, K=None):
    """Return 1/(alpha-1) * log(1 + sum_{n=1}^{B} (per-n term)).

    The terms are weighted by a_n = C(B,n) D^(-n) (1-1/D)^(B-n).
    (Presumably B is the minibatch size and D the dataset size -- confirm
    against the caller.)

    * For n = 1..K each term is a_n/q_tilde times the smaller of the
      expansion bound log_H_bound_minus_1(alpha, sigma/n, m, q_tilde) and the
      worst-case bound exp(2 alpha (alpha-1) n^2 / sigma^2) - 1.
    * For n = K+1..B each term is a_n times the worst-case bound.

    K defaults to min(m-1, B). With K == 0 every term uses the worst-case
    bound and q_tilde is not needed, so it is not computed.
    """
    if K is None:
        K = np.minimum(m-1, B)

    def log_worst_case(n):
        #log(exp(x)-1) with x=2*alpha*(alpha-1)*n^2/sigma^2, written as log(1-exp(-x))+x
        x = 2*alpha*(alpha-1)*n**2/sigma**2
        return np.log(1.-np.exp(-x))+x

    def log_binomial_weight(n):
        #log of a_n = C(B,n) D^(-n) (1-1/D)^(B-n)
        return log_k_choose_n(B, n)+np.log(D)*(-n)+np.log(1.-1./D)*(B-n)

    exponents = [0.]

    #q_tilde is only defined (and only used) when at least one term takes the
    #expansion bound; skipping it for K < 1 avoids dividing by a zero mass.
    if K >= 1:
        mass = 0.
        for n in range(1, K+1):
            mass = mass+comb(B, n, exact=True)*D**(-n)*(1.-1./D)**(B-n)
        q_tilde = 1./(1.+(1.-1./D)**B/mass)
        log_q_tilde = np.log(q_tilde)

        for n in range(1, K+1):
            log_a_n_tilde = log_binomial_weight(n)-log_q_tilde
            taylor = log_H_bound_minus_1(alpha, sigma/n, m, q_tilde)
            #take minimum of Taylor expansion bound and worst-case bound of each term
            exponents.append(log_a_n_tilde+np.minimum(taylor, log_worst_case(n)))

    for n in range(K+1, B+1):
        exponents.append(log_binomial_weight(n)+log_worst_case(n))

    #sum the terms after subtracting the largest exponent
    exp_max = np.max(exponents)
    return 1./(alpha-1.)*(np.log(np.sum(np.exp(np.asarray(exponents)-exp_max)))+exp_max)




#Lower bound on FS-RDP with replacement from Theorem 3.5 and Appendix C.2.1.
    
#lb computed using the relaxation from Appendix C.2.1: indices_function(k) should return indices of terms used in stage k of the recursion
def log_FS_wR_LB_inductive_approx(k, c, d, B, D, indices_function):
    """Return the log of the stage-k sum of the recursion, restricted to the
    indices n yielded by ``indices_function(k)``.

    Every term carries the weight a_n = C(B,n) D^(-n) (1-1/D)^(B-n) and the
    factor exp(d*n). At k == 2 the term is closed with
    ((1-1/D) exp(-(c*n+d)) + 1/D)^B * exp(B*(c*n+d)); for larger k it is
    multiplied by the stage k-1 value evaluated at d + c*n. The recursion
    only terminates at k == 2.
    """
    is_base_case = (k == 2)
    log_F_terms = []
    for n in indices_function(k):
        if is_base_case:
            log_a_n = log_k_choose_n(B, n)-n*np.log(D)+(B-n)*np.log(1-1/D)
            log_F_terms.append(
                log_a_n+d*n+B*np.log((1-1/D)*np.exp(-(c*n+d))+D**(-1))+B*(c*n+d)
            )
        else:
            log_F_terms.append(
                log_k_choose_n(B, n)+np.log(D)*(-n)+np.log(1-1/D)*(B-n)+d*n
                +log_FS_wR_LB_inductive_approx(k-1, c, d+c*n, B, D, indices_function)
            )

    #sum the terms after subtracting the largest log value
    max_log_F = np.max(log_F_terms)
    scaled_terms = np.exp(np.asarray(log_F_terms)-max_log_F)
    return np.log(np.sum(scaled_terms))+max_log_F
    
#FSR lower bound computed using selected indices
def FS_RDP_wR_LB_approximate(alpha, sigma, B, D, indices_function):
    """Return 1/(alpha-1) times log_FS_wR_LB_inductive_approx evaluated at
    k = alpha, c = 4/sigma^2 and d = 0.
    """
    scale = 1/(alpha-1)
    c = 4/sigma**2
    return scale*log_FS_wR_LB_inductive_approx(alpha, c, 0, B, D, indices_function)


#exact formula for the Theorem 3.4 lower bound on FS-RDP with replacement (don't use unless |B| and alpha are small!)
#exact lb for alpha=2
def FS_wR_LB_2_exact(c, d, B, D):
    """Return the log of the k == 2 stage sum over all n = 0..B.

    Term n is a_n * exp(d*n) * ((1-1/D) exp(-(c*n+d)) + 1/D)^B * exp(B*(c*n+d))
    with a_n = C(B,n) D^(-n) (1-1/D)^(B-n); the terms are summed after
    subtracting the largest log value.
    """
    def log_term(n):
        log_a_n = log_k_choose_n(B, n)-n*np.log(D)+(B-n)*np.log(1-1/D)
        return log_a_n+d*n+B*np.log((1-1/D)*np.exp(-(c*n+d))+D**(-1))+B*(c*n+d)

    log_F_terms = [log_term(n) for n in range(B+1)]

    max_log_F = np.max(log_F_terms)
    scaled_terms = np.exp(np.asarray(log_F_terms)-max_log_F)
    return np.log(np.sum(scaled_terms))+max_log_F

def FS_wR_LB_inductive_exact(k, c, d, B, D):
    """Return the stage-k sum of the recursion over all n = 0..B (not in log
    space).

    At k == 2 this is exp(FS_wR_LB_2_exact(c, d, B, D)). For larger k each n
    contributes C(B,n) D^(-n) (1-1/D)^(B-n) exp(d*n) times the stage k-1 value
    at d + c*n, so the work grows like (B+1)^(k-2) base-case evaluations.
    """
    if k == 2:
        return np.exp(FS_wR_LB_2_exact(c, d, B, D))

    total = 0
    for n in range(B+1):
        weight = comb(B, n, exact=True)*D**(-n)*(1-1/D)**(B-n)*np.exp(d*n)
        total = total+weight*FS_wR_LB_inductive_exact(k-1, c, d+c*n, B, D)
    return total

def FS_RDP_wR_LB_exact(alpha, sigma, B, D):
    """Return 1/(alpha-1) times the log of FS_wR_LB_inductive_exact evaluated
    at k = alpha, c = 4/sigma^2 and d = 0.
    """
    scale = 1/(alpha-1)
    c = 4/sigma**2
    return scale*np.log(FS_wR_LB_inductive_exact(alpha, c, 0, B, D))




