# libraries 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib, time
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import dppy
from dppy.finite_dpps import FiniteDPP
from sklearn.datasets import load_boston
import scipy
import sys


# function definitions 
def convert_outputs2averages(outputs,verbose=True): 
    """Return the running (prefix) means of ``outputs``.

    Entry ``i`` of the returned list is the mean of ``outputs[0]`` through
    ``outputs[i]``.  The inputs are not modified.

    Args:
        outputs: non-empty sequence of arrays that can be added together.
        verbose: when true, print the index of every processed element.

    Returns:
        A list with ``len(outputs)`` arrays.
    """
    running_total = outputs[0].copy()
    prefix_means = [running_total / 1]
    if verbose:
        print("averaging progress: ", end="")
    for count, current in enumerate(outputs[1:], start=2):
        if verbose:
            print(count - 1, end=", ")
        running_total = running_total + current
        # `count` is the number of outputs folded into the total so far
        prefix_means.append(running_total / count)
    if verbose:
        print()
    return prefix_means
def convert_outputs2averages_weighted(outputs,weights,verbose=True): 
    """Return the running weighted means of ``outputs``.

    Entry ``i`` of the returned list is

        sum_{j <= i} weights[j] * outputs[j]  /  sum_{j <= i} weights[j]

    Every output, including the first one, is multiplied by its weight
    before being accumulated, so the numerator and the normalizer always
    cover the same terms.

    Args:
        outputs: non-empty sequence of arrays that can be added together.
        weights: sequence of scalars, at least as long as ``outputs``.
        verbose: when true, print the index of every processed element.

    Returns:
        A list with ``len(outputs)`` arrays.
    """
    weighted_total = outputs[0] * weights[0]
    weight_total = weights[0]
    averages = [weighted_total / weight_total]
    if verbose:
        print("averaging progress: ", end="")
    for i in range(1, len(outputs)):
        if verbose:
            print(i, end=", ")
        weighted_total = weighted_total + outputs[i] * weights[i]
        # keep the normalizer as a running sum instead of re-summing the prefix
        weight_total = weight_total + weights[i]
        averages.append(weighted_total / weight_total)
    if verbose:
        print()
    return averages
def compute_effective_dimension(ATA, lm):
    """Return trace(ATA @ inv(ATA + lm * I)) for the square matrix ``ATA``."""
    size = ATA.shape[0]
    regularized_inverse = np.linalg.inv(ATA + lm*np.identity(size))
    return np.trace(ATA @ regularized_inverse)

# sketching functions:
def gaussian_sketch(A, y, m): 
    """Apply an m-row i.i.d. Gaussian sketch to ``A`` and ``y``.

    The sketching matrix has N(0, 1) entries scaled by 1/sqrt(m); the same
    matrix multiplies both inputs.

    Returns:
        (S @ A, S @ y)
    """
    num_rows, _ = A.shape
    sketch = np.random.normal(0,1,(m,num_rows)) / np.sqrt(m)
    return sketch @ A, sketch @ y
def uniform_sampling(A, y, m): 
    """Pick ``m`` distinct rows uniformly at random and rescale by sqrt(n/m).

    The same rows are taken from ``A`` and from the 2-D array ``y``.

    Returns:
        (SA, Sy), the selected and rescaled rows.
    """
    num_rows, _ = A.shape
    picked = np.random.choice(num_rows, m, replace=False)
    rescale = np.sqrt(num_rows / m)
    return A[picked,:] * rescale, y[picked,:] * rescale
def ridgelev_sketch(A, y, m, scores): 
    """Sample ``m`` distinct rows with probabilities ``scores`` and reweight.

    Each selected row ``r`` is multiplied by sqrt(1 / (m * scores[r])).

    Args:
        A: 2-D data matrix.
        y: 2-D target array with one row per row of ``A``.
        m: number of rows to draw (without replacement).
        scores: 1-D sampling probabilities, one per row of ``A``.

    Returns:
        (SA, Sy), the selected and reweighted rows.
    """
    num_rows, _ = A.shape
    picked = np.random.choice(num_rows, size=m, replace=False, p=scores)
    # one weight per selected row, broadcast across that row's columns
    row_weights = np.sqrt(1/(m*scores[picked]))[:, np.newaxis]
    return row_weights * A[picked,:], row_weights * y[picked,:]
def ridgelev_precompute(AAT, lm1):
    """Return normalized sampling scores derived from the kernel ``AAT``.

    The raw scores are the diagonal of AAT @ inv(AAT + lm1 * n * I), where
    n is the size of ``AAT``; they are divided by their sum so that the
    result adds up to one.
    """
    size = AAT.shape[0]
    shifted_inverse = np.linalg.inv(AAT + lm1*size*np.identity(size))
    raw_scores = np.diag(AAT @ shifted_inverse)
    return raw_scores / np.sum(raw_scores)
def rademacher_sketch(A, y, m): 
    """Apply an m-row random-sign sketch to ``A`` and ``y``.

    The sketching matrix has entries drawn from {-1, +1} and is scaled by
    1/sqrt(m); the same matrix multiplies both inputs.

    Returns:
        (S @ A, S @ y)
    """
    num_rows, _ = A.shape
    signs = np.random.choice(np.array([-1,1]), size=(m,num_rows))
    sketch = signs / np.sqrt(m)
    return sketch @ A, sketch @ y
def surrogate_p1_sketch(A, y, m, DPP, effective_dim): 
    """Row-sampling sketch: one DPP draw plus a Poisson number of uniform rows.

    The selected rows are ``DPP.sample_exact()`` followed by ``M`` distinct
    uniformly chosen rows, with ``M ~ Poisson(m - effective_dim)``.  All
    selected rows are rescaled by sqrt(n / number_of_selected_rows).

    Args:
        A: 2-D data matrix with n rows.
        y: 2-D target array with n rows.
        m: target sketch size.
        DPP: object exposing ``sample_exact()`` that returns row indices.
        effective_dim: value subtracted from ``m`` to get the Poisson rate.

    Returns:
        (SA, Sy), the selected and rescaled rows.
    """
    num_rows, _ = A.shape

    dpp_rows = np.array(DPP.sample_exact())

    # number of extra uniform rows is itself random
    extra_count = np.random.poisson(m - effective_dim)
    uniform_rows = np.random.choice(num_rows, extra_count, replace=False)

    rows = np.concatenate((dpp_rows, uniform_rows), axis=0).astype(int)
    rescale = np.sqrt(num_rows / (rows.shape[0]))
    return A[rows,:] * rescale, y[rows,:] * rescale
def surrogate_p1_precompute(AAT, lm1): 
    """Build the ``FiniteDPP`` used by the surrogate_p1 sketch.

    The DPP is created as a non-projection 'likelihood' kernel with
    ``L = AAT / lm1``.
    """
    likelihood_kernel = AAT / lm1
    return FiniteDPP(kernel_type='likelihood', projection=False, L=likelihood_kernel)
def surrogate_p2_sketch(A, y, m, effective_dim, t):
    """Row-sampling sketch whose first k rows come from an MCMC swap chain.

    Starting from k = ceil(effective_dim) uniformly chosen distinct rows S,
    the function runs ``t`` proposal steps.  Each step proposes adding one
    row ``ii`` not in S and removing the entry at position ``jj_ind`` of
    the augmented set S + {ii}; the proposal is accepted with probability
    min(1, exp(logdet(A_T A_T^T) - logdet(A_S A_S^T))).  After the chain,
    m - k further distinct rows are drawn uniformly from all n rows, and
    every selected row is rescaled by sqrt(n / m).

    Args:
        A: 2-D data matrix with n rows.
        y: 2-D target array with n rows.
        m: total number of rows in the sketch.
        effective_dim: its ceiling is used as the size k of the chain state.
        t: number of proposal steps.

    Returns:
        (SA, Sy), the selected and rescaled rows of ``A`` and ``y``.
    """
    n,d = A.shape
    
    k = int(np.ceil(effective_dim))
    S_samples = np.random.choice(n, k, replace=False)
            
    # initial computation of determinant and matrix inversion
    # NOTE: det_A_S holds the *log* of |det(A_S A_S^T)| (second output of
    # slogdet), not the determinant itself.
    A_S = A[S_samples, :]
    det_A_S = np.linalg.slogdet(np.matmul(A_S, A_S.T))[1]
    inv_A_SA_ST = np.linalg.inv(np.matmul(A_S, A_S.T))
            
    for iter_no in range(t):
        # step 2.a: build the list of row indices that are not currently in S
        ind = np.ones((n,), bool)
        ind[S_samples] = False
        to_sample_from = np.linspace(0,n-1,n, dtype='int')
        to_sample_from = to_sample_from[ind]
                
        # ii: candidate row to add, drawn from the rows outside S
        ii = np.random.choice(to_sample_from, 1, replace=False)[0] # should we sample from [n] \ S or [n]
        # jj_ind: position (0..k) in the augmented set to remove; k refers to ii itself
        jj_ind = np.random.choice(k+1, 1, replace=False)[0] # changed to k+1          
        # adding i: update log-det and inverse for S + {ii}
        a_i = A[ii:ii+1, :]
        A_S = A[S_samples, :]
        det_i, inv_i = S_plus_ii_update(det_A_S, inv_A_SA_ST, A_S, a_i)
                                
        # removing j: ii is appended last, so index k of the augmented set is ii
        #S_samples_i_added = np.concatenate((S_samples, np.array([ii])))
        S_samples_i_added = np.zeros((k+1), dtype=int)
        S_samples_i_added[0:k] = S_samples.copy()
        S_samples_i_added[k] = ii
        
        A_S_i_added = A[S_samples_i_added, :]
        a_j = A[S_samples_i_added[jj_ind]:S_samples_i_added[jj_ind]+1, :]
        det_i_j, inv_i_j = S_minus_jj_update(det_i, inv_i, A_S_i_added, a_j, jj_ind)
        
        # both quantities are log-determinants, so this is the determinant ratio
        ratio = np.exp(det_i_j - det_A_S)
                                
        prob = min(1, ratio)
        if np.random.binomial(1, prob, size=1) == 1: # move from S to T
            # jj_ind == k means the freshly added row was removed again, so S is unchanged
            if jj_ind != A_S.shape[0]: # otherwise no update
                # update the determinant and inverse for T = S+ii-jj
                # new ordering: S without position jj_ind, followed by ii;
                # this matches the row/column ordering of inv_i_j
                S_samples_temp = np.zeros((k), dtype=int)
                S_samples_temp[0:jj_ind] = S_samples[0:jj_ind].copy()
                S_samples_temp[jj_ind:-1] = S_samples[jj_ind+1:].copy()
                S_samples_temp[-1] = ii
                S_samples = S_samples_temp.copy()
                        
                det_A_S = det_i_j
                inv_A_SA_ST = inv_i_j.copy()
                        
                
    # step 3: m - k extra rows, drawn uniformly from all n rows
    # NOTE(review): these are drawn independently of S_samples, so a row can
    # appear both in S_samples and here -- confirm this is intended.
    other_samples = np.random.choice(n, m-k, replace=False)

    # step 4, 5: stack chain rows first, uniform rows after, then rescale
    #sigmas = np.concatenate((S_samples, other_samples), axis=0)
    sigmas = np.zeros((m), dtype=int)
    sigmas[:k] = S_samples.copy()
    sigmas[k:] = other_samples.copy()
    
    #sigmas = sigmas.astype(int)

    SA = A[sigmas,:] * np.sqrt(n / (sigmas.shape[0]))
    Sy = y[sigmas,:] * np.sqrt(n / (sigmas.shape[0]))
    
    return SA, Sy
def surrogate_p3_sketch(A, y, m, DPP, effective_dim):
    """Row-sampling sketch: a k-DPP draw plus m - k uniform rows.

    With k = ceil(effective_dim), the selected rows are
    ``DPP.sample_exact_k_dpp(k)`` followed by m - k distinct uniformly
    chosen rows.  All selected rows are rescaled by
    sqrt(n / number_of_selected_rows).

    Args:
        A: 2-D data matrix with n rows.
        y: 2-D target array with n rows.
        m: total sketch size.
        DPP: object exposing ``sample_exact_k_dpp(k)`` that returns row indices.
        effective_dim: its ceiling is the size of the k-DPP sample.

    Returns:
        (SA, Sy), the selected and rescaled rows.
    """
    num_rows, _ = A.shape

    dpp_size = int(np.ceil(effective_dim))
    dpp_rows = np.array(DPP.sample_exact_k_dpp(dpp_size))
    uniform_rows = np.random.choice(num_rows, m-dpp_size, replace=False)

    rows = np.concatenate((dpp_rows, uniform_rows), axis=0).astype(int)
    rescale = np.sqrt(num_rows / (rows.shape[0]))
    return A[rows,:] * rescale, y[rows,:] * rescale
def surrogate_p3_precompute(AAT, lm1): 
    """Build the ``FiniteDPP`` used by the surrogate_p3 sketch.

    The DPP is created as a non-projection 'likelihood' kernel with
    ``L = AAT / lm1``.
    """
    likelihood_kernel = AAT / lm1
    return FiniteDPP(kernel_type='likelihood', projection=False, L=likelihood_kernel)

# surrogate_p2 sketch helper functions -- mcmc approach
def S_plus_ii_update(det_A_S, inv_A_SA_ST, A_S, a_i):
    """Update log-det and inverse of A_S A_S^T after appending row ``a_i``.

    Uses the Schur complement h = a_i a_i^T - b^T G^{-1} b, with
    G = A_S A_S^T and b = A_S a_i^T, so that

        logdet(G_new) = logdet(G) + log(h)
        G_new^{-1}    = [[h G^{-1} + G^{-1} b b^T G^{-1}, -G^{-1} b],
                         [-b^T G^{-1},                     1       ]] / h

    Args:
        det_A_S: log-determinant of A_S A_S^T.
        inv_A_SA_ST: inverse of A_S A_S^T.
        A_S: currently selected rows, one per row.
        a_i: the row being added, as a 1 x d array.

    Returns:
        (new log-determinant, new inverse); the added row is last in the
        ordering of the new inverse.
    """
    gram_inv = inv_A_SA_ST
    cross = A_S @ a_i.T  # b = A_S a_i^T, a single column
    schur = a_i @ a_i.T - cross.T @ (gram_inv @ cross)  # 1 x 1 array

    grown_logdet = det_A_S + np.log(schur[0,0])

    upper_left = schur*gram_inv + gram_inv @ (cross @ (cross.T @ gram_inv))
    right_col = -(gram_inv @ cross)
    bottom_row = -(cross.T @ gram_inv)
    corner = np.array([[1.]])

    unscaled = np.block([[upper_left, right_col], [bottom_row, corner]])
    return grown_logdet, unscaled / schur
def S_minus_jj_update(det_A_S, inv_A_SA_ST, A_S, a_j, jj):
    """Update log-det and inverse of A_S A_S^T after deleting row ``jj``.

    With P = (A_S A_S^T)^{-1} and p = P[jj, jj]:

        logdet(G_new) = logdet(G) + log(p)
        G_new^{-1}    = P[-jj, -jj] - P[-jj, jj] P[jj, -jj] / p

    where ``-jj`` means "all indices except jj".  Only the current inverse
    is needed for the update, so the row data itself is never touched.

    Args:
        det_A_S: log-determinant of A_S A_S^T.
        inv_A_SA_ST: inverse of A_S A_S^T.
        A_S: currently selected rows; only its row count is used.
        a_j: the row being removed; unused, kept for interface compatibility.
        jj: position (within A_S) of the row to remove.

    Returns:
        (new log-determinant, new inverse) with row/column ``jj`` dropped
        and the remaining ordering preserved.
    """
    n = A_S.shape[0]

    # update det: H_S_jj is the reciprocal of the jj'th diagonal of the inverse
    pivot = inv_A_SA_ST[jj, jj]
    H_S_jj = 1 / pivot
    det_A_S_jj = det_A_S - np.log(H_S_jj)

    keep = np.ones((n,), bool)
    keep[jj] = False

    # update inv: subtract the rank-one correction from the surviving block
    kept_block = inv_A_SA_ST[keep, :][:, keep]
    kept_col = inv_A_SA_ST[keep, jj:jj+1]
    kept_row = inv_A_SA_ST[jj:jj+1, keep]
    inv_A_S_jjA_S_jjT = kept_block - np.matmul(kept_col, kept_row) * H_S_jj

    return det_A_S_jj, inv_A_S_jjA_S_jjT

def get_errors(A, y, m, lm1, lm2, num_workers, sketch_type, effective_dim, ATA, ATy, AAT, t=0, A_test=None,y_test=None):
    """Measure how averaging sketched ridge solutions approaches the optimum.

    Each of ``num_workers`` workers draws an independent sketch (SA, Sy) and
    solves (SA^T SA + lm2 I) x = SA^T Sy.  Entry ``j`` of every returned
    array refers to the average of the first ``j + 1`` worker solutions.

    Args:
        A, y: training data (n x d) and targets.
        m: sketch size passed to the sketching function.
        lm1: regularization of the reference solution x_star.
        lm2: regularization used inside each worker solve.
        num_workers: number of independent sketches to average.
        sketch_type: one of "gaus", "unif", "ridgelev", "rademacher",
            "surrogate_p1", "surrogate_p2", "surrogate_p3".
        effective_dim: forwarded to the surrogate sketches.
        ATA, ATy, AAT: precomputed A^T A, A^T y and A A^T.
        t: number of MCMC steps for "surrogate_p2".
        A_test, y_test: optional held-out data; accuracies are only filled
            in when ``A_test`` is given.

    Returns:
        (errors, elapsed_seconds, acc_train, acc_val) where ``errors[j]`` is
        ||x_star - avg_j|| / ||x_star|| and the accuracies threshold the
        predictions of ``avg_j`` at 0.5.

    Raises:
        ValueError: if ``sketch_type`` is not one of the supported names.
    """
    n,d = A.shape
    # optimal solution x_star
    x_star = np.linalg.solve(ATA + lm1*np.identity(d), ATy)

    # precomputing for some of the sketches
    scores = None
    DPP = None
    if sketch_type == "ridgelev":
        scores = ridgelev_precompute(AAT, lm1)
    elif sketch_type == "surrogate_p1":
        DPP = surrogate_p1_precompute(AAT, lm1)
    elif sketch_type == "surrogate_p3":
        DPP = surrogate_p3_precompute(AAT, lm1)

    # per-worker sketched solutions
    start_time = time.time()

    outputs = [None] * num_workers
    for j in range(num_workers):
        if sketch_type == "gaus":
            SA, Sy = gaussian_sketch(A, y, m)
        elif sketch_type == "unif":
            SA, Sy = uniform_sampling(A, y, m)
        elif sketch_type == "ridgelev":
            SA, Sy = ridgelev_sketch(A, y, m, scores)
        elif sketch_type == "rademacher":
            SA, Sy = rademacher_sketch(A, y, m)
        elif sketch_type == "surrogate_p1":
            SA, Sy = surrogate_p1_sketch(A, y, m, DPP, effective_dim)
        elif sketch_type == "surrogate_p2":
            SA, Sy = surrogate_p2_sketch(A, y, m, effective_dim, t)
        elif sketch_type == "surrogate_p3":
            SA, Sy = surrogate_p3_sketch(A, y, m, DPP, effective_dim)
        else:
            raise ValueError("unknown sketch_type: {!r}".format(sketch_type))

        # compute x_tilde for this worker
        outputs[j] = np.linalg.solve(np.matmul(SA.T,SA) + lm2*np.identity(d), np.matmul(SA.T,Sy))

    end_time = time.time()

    # convert outputs to running averages
    averages = convert_outputs2averages(outputs, False)

    # compute errors
    x_star_norm = np.sqrt(np.sum(x_star**2))
    errors = np.zeros((num_workers))
    acc_train = np.zeros((num_workers))
    acc_val = np.zeros((num_workers))
    for j, averaged in enumerate(averages):
        errors[j] = np.sqrt(np.sum((x_star - averaged)**2)) / x_star_norm

        if A_test is not None:
            # evaluate the averaged estimate itself (not x_star), so that the
            # accuracy curves actually vary with the number of workers
            acc_train[j] = np.sum((np.matmul(A, averaged) >= 0.5) == y) / y.shape[0]
            acc_val[j] = np.sum((np.matmul(A_test, averaged) >= 0.5) == y_test) / y_test.shape[0]

    return errors, end_time - start_time, acc_train, acc_val
def get_errors_singlenewtonstep(A, y, m, lm1, lm2, num_workers, sketch_type, effective_dim, ATA, ATy, AAT, t=0, A_test=None,y_test=None, determ_avg=False):
    """Measure how averaging sketched Newton steps approaches the exact step.

    The reference is x_star = (A^T A + lm1 I)^{-1} g with g = -A^T y.  Each
    worker draws a sketch SA, forms H_hat = SA^T SA + lm2 I and returns
    H_hat^{-1} g * lm2 / lm1.  Entry ``j`` of every returned array refers
    to the average of the first ``j + 1`` worker outputs.

    Args:
        A, y: training data (n x d) and targets.
        m: sketch size passed to the sketching function.
        lm1: regularization of the reference step.
        lm2: regularization used inside each worker.
        num_workers: number of independent sketches to average.
        sketch_type: one of "gaus", "unif", "ridgelev", "rademacher",
            "surrogate_p1", "surrogate_p2", "surrogate_p3".
        effective_dim: forwarded to the surrogate sketches.
        ATA, ATy: unused here; kept for a signature parallel to get_errors.
        AAT: precomputed A A^T for the sketches that need a precompute.
        t: number of MCMC steps for "surrogate_p2".
        A_test, y_test: optional held-out data; accuracies are only filled
            in when ``A_test`` is given.
        determ_avg: when true, worker outputs are averaged with weights
            det(H_hat) instead of uniformly.

    Returns:
        (errors, elapsed_seconds, acc_train, acc_val) where ``errors[j]`` is
        ||x_star - avg_j|| / ||x_star|| and the accuracies threshold the
        predictions of ``avg_j`` at 0.5.

    Raises:
        ValueError: if ``sketch_type`` is not one of the supported names.
    """
    n,d = A.shape
    gradient = -np.matmul(A.T,y)

    # reference step; solve the linear system instead of forming the inverse
    x_star = np.linalg.solve(np.matmul(A.T, A) + lm1*np.identity(d), gradient)

    # precomputing for some of the sketches
    scores = None
    DPP = None
    if sketch_type == "ridgelev":
        scores = ridgelev_precompute(AAT, lm1)
    elif sketch_type == "surrogate_p1":
        DPP = surrogate_p1_precompute(AAT, lm1)
    elif sketch_type == "surrogate_p3":
        DPP = surrogate_p3_precompute(AAT, lm1)

    if determ_avg: # determinantal averaging
        det_weights = np.zeros((num_workers))

    # per-worker sketched steps
    start_time = time.time()

    outputs = [None] * num_workers
    for j in range(num_workers):
        if sketch_type == "gaus":
            SA, Sy = gaussian_sketch(A, y, m)
        elif sketch_type == "unif":
            SA, Sy = uniform_sampling(A, y, m)
        elif sketch_type == "ridgelev":
            SA, Sy = ridgelev_sketch(A, y, m, scores)
        elif sketch_type == "rademacher":
            SA, Sy = rademacher_sketch(A, y, m)
        elif sketch_type == "surrogate_p1":
            SA, Sy = surrogate_p1_sketch(A, y, m, DPP, effective_dim)
        elif sketch_type == "surrogate_p2":
            SA, Sy = surrogate_p2_sketch(A, y, m, effective_dim, t)
        elif sketch_type == "surrogate_p3":
            SA, Sy = surrogate_p3_sketch(A, y, m, DPP, effective_dim)
        else:
            raise ValueError("unknown sketch_type: {!r}".format(sketch_type))

        # compute x_tilde: sketched Newton step, rescaled by lm2 / lm1
        H_hat = np.matmul(SA.T,SA) + lm2*np.identity(d)
        x_tilde = np.linalg.solve(H_hat, gradient)
        x_tilde = x_tilde * lm2 / lm1

        if determ_avg: # determinantal averaging
            # NOTE(review): the raw determinant can overflow/underflow for
            # large d; slogdet-based weights may be needed -- confirm.
            det_weights[j] = np.linalg.det(H_hat)

        # save x_tilde in outputs
        outputs[j] = x_tilde

    end_time = time.time()

    # convert outputs to running averages (only the variant that is used)
    if determ_avg:
        averages = convert_outputs2averages_weighted(outputs, det_weights, False)
    else:
        averages = convert_outputs2averages(outputs, False)

    # compute errors
    x_star_norm = np.sqrt(np.sum(x_star**2))
    errors = np.zeros((num_workers))
    acc_train = np.zeros((num_workers))
    acc_val = np.zeros((num_workers))
    for j, averaged in enumerate(averages):
        errors[j] = np.sqrt(np.sum((x_star - averaged)**2)) / x_star_norm

        if A_test is not None:
            # evaluate the averaged estimate itself (not x_star), so that the
            # accuracy curves actually vary with the number of workers
            acc_train[j] = np.sum((np.matmul(A, averaged) >= 0.5) == y) / y.shape[0]
            acc_val[j] = np.sum((np.matmul(A_test, averaged) >= 0.5) == y_test) / y_test.shape[0]

    return errors, end_time - start_time, acc_train, acc_val

# function for loading datasets
def dataset_loader(dataset_name):
    """Load a dataset by name and return ``(A, y)``.

    * ``"boston"``: features and targets from ``load_boston``; the targets
      are returned as a column.
    * ``"cifar"``: the five CIFAR-10 training batches read from
      ``datasets/cifar-10-batches-py/``, restricted to labels 0 and 1; the
      columns of ``A`` are centered and scaled to unit Euclidean norm.
    * anything else: ``datasets/data_<name>.npy`` and
      ``datasets/output_<name>.npy``.

    The shapes of the loaded arrays are printed before returning.
    """
    if dataset_name == "boston":
        # boston house prices
        # NOTE(review): load_boston was removed in scikit-learn 1.2; this
        # branch (and the file-level import) need an older scikit-learn.
        features, targets = load_boston(return_X_y=True)
        A = features
        y = np.expand_dims(targets, axis=1)

    elif dataset_name == "cifar":
        # cifar-10
        # NOTE: pickle executes arbitrary code on load; only point this at
        # trusted copies of the dataset files.
        import pickle

        directory = "datasets/cifar-10-batches-py/"

        A = np.zeros((50000, 3072))
        train_labels = np.zeros((50000), dtype=int)
        for batch_no in range(1,6): # the data is in 5 batches of 10000 rows
            rows = slice((batch_no-1)*10000, batch_no*10000)
            with open(directory + "data_batch_" + str(batch_no), 'rb') as fo:
                batch = pickle.load(fo, encoding='bytes')
            A[rows, :] = batch[b'data'].astype(np.float64)
            train_labels[rows] = np.array(batch[b'labels'])
        y = train_labels.copy()

        # keep only the classes 0 and 1
        inds = np.argwhere(y <= 1)[:,0]
        A = A[inds, :]
        y = y[inds].reshape((inds.shape[0], 1))

        # center every column, then divide it by its Euclidean norm
        A = A - np.mean(A, axis=0)
        A = A / np.sqrt(np.sum(A**2, axis=0))

    else:
        A = np.load("datasets/data_{}.npy".format(dataset_name))
        y = np.load("datasets/output_{}.npy".format(dataset_name))

    print("Dataset dimensions are: A={}, y={}".format(A.shape, y.shape))

    return A, y


# Distributed newton sketch for regularized problems
def backtracking_linesearch(tau, c, theta, f_theta, descend_direction, gradient, cost_func, a0=1):
    """Return a step size found by backtracking from ``a0``.

    Starting at ``a0``, the step is divided by ``tau`` until

        cost_func(theta + step * d) <= f_theta + step * c * gradient^T d

    holds, where ``d`` is ``descend_direction`` and ``f_theta`` is the cost
    at ``theta``.

    Args:
        tau: factor the step is divided by after every rejected trial.
        c: sufficient-decrease constant.
        theta: current point.
        f_theta: value of ``cost_func`` at ``theta``.
        descend_direction: search direction.
        gradient: gradient of the cost at ``theta``.
        cost_func: callable evaluating the cost at a point.
        a0: initial step size.

    Returns:
        The first step size that satisfies the condition above.
    """
    # directional derivative along the search direction (does not depend on the step)
    slope = np.matmul(gradient.T, descend_direction)
    step = a0
    while True:
        trial_cost = cost_func(theta + step*descend_direction)
        if not trial_cost > (f_theta + step*c*slope):
            return step
        step = step / tau
# logistic regression
def log_regression_optimize(A,b,c_vec,mode,lm1,lm2,tau,c,num_iters,num_workers,m,a0=1,momentum=False, sketch_type="unif", m2=0, t=0, beta=None):
    """Minimize the l2-regularized logistic loss starting from x = 0.

    The objective is

        lm1/2 * ||x||^2 - sum_i [ b_i log p_i + (1 - b_i) log(1 - p_i) ],

    with p = sigmoid(A x).  Every iteration picks a direction according to
    ``mode`` and a step size by backtracking line search (all modes):

    * mode 1: exact Newton direction.
    * mode 2: average over ``num_workers`` of Newton directions computed
      from sketched Hessians (sketch of D^{1/2} A with D = diag(p(1-p))).
    * mode 3: gradient descent, optionally with momentum.

    Args:
        A: data matrix (n x d).
        b: labels as an (n x 1) array.
        c_vec: unused.
        mode: 1, 2 or 3 as described above.
        lm1: regularization of the objective.
        lm2: regularization of the sketched Hessians (mode 2).  The special
            value -999 selects lm1 * (1 - effective_dim / m) and rescales
            the averaged direction by that value divided by lm1.
        tau, c, a0: line-search parameters (see backtracking_linesearch).
        num_iters: number of iterations.
        num_workers: number of sketches averaged per iteration (mode 2).
        m: sketch size (mode 2).
        momentum: enable the momentum term in mode 3.
        sketch_type: one of "gaus", "unif", "ridgelev", "rademacher",
            "surrogate_p1", "surrogate_p2", "surrogate_p3" (mode 2).
        m2: unused.
        t: number of MCMC steps for the "surrogate_p2" sketch.
        beta: momentum coefficient; required when ``momentum`` is true in
            mode 3.

    Returns:
        (x_list, costs, times): the iterate after every iteration, the cost
        at the start and after every iteration, and the matching
        ``time.time()`` stamps.

    Raises:
        ValueError: if momentum is requested in mode 3 without ``beta``, or
            if ``sketch_type`` is unknown in mode 2.
    """
    n, d = A.shape

    if mode == 3 and momentum and beta is None:
        raise ValueError("momentum=True requires a momentum coefficient `beta`")

    def log_sigmoid(z):
        # log(1 / (1 + exp(-z))), shifted (log-sum-exp trick) so that no
        # exponential overflows
        shift = np.maximum(0.0, -z)
        return -shift - np.log(np.exp(-shift) + np.exp(-z-shift))

    def specific_cost_fn(x):
        # uses A, b and lm1 from the enclosing scope
        A_x = np.matmul(A, x)
        log_p = log_sigmoid(A_x)

        term1 = np.sum(b*log_p)
        # log(1 - p) = -A_x + log(p)
        term2 = np.sum((1-b) * (-A_x + log_p))
        return lm1/2*np.sum(x**2) - term1 - term2

    x_list = [None] * num_iters
    costs = np.zeros(num_iters+1)
    times = np.zeros((num_iters+1))

    x = np.zeros((d,1))
    x_prev = x.copy()  # previous iterate for the momentum term; zero effect on the first step
    costs[0] = specific_cost_fn(x)
    times[0] = time.time()

    # the sketch helpers also sketch a target vector; none is needed here
    no_targets = np.zeros((n,1))

    for iter_no in range(num_iters):
        print("iteration: [{}/{}]".format(iter_no, num_iters))

        p = np.exp(log_sigmoid(np.matmul(A, x)))
        # diagonal of the logistic Hessian weights, kept as an (n, 1) column
        # instead of a dense n x n diagonal matrix
        curvature = p * (1-p)

        gradient = np.matmul(A.T, (p-b)) + lm1*x
        f_x = costs[iter_no]  # cost at the current x, already evaluated

        if mode == 1: # exact hessian
            hessian = np.matmul(A.T, curvature * A) + lm1*np.identity(d)
            update_dir = np.linalg.solve(hessian, gradient)

            mu = backtracking_linesearch(tau,c,x,f_x,-update_dir,gradient,specific_cost_fn,a0)
            x = x + mu * (-update_dir)

        elif mode == 2: # sketched hessian averaging
            Dhalf_A = np.sqrt(curvature) * A

            effective_dim = compute_effective_dimension(np.matmul(Dhalf_A.T, Dhalf_A), lm1)
            # regularization used inside the sketched Hessians
            if lm2 == -999:
                lm2_star = lm1 * (1 - effective_dim/m)
            else:
                lm2_star = lm2
            sketch_reg = lm1 if lm1 == lm2 else lm2_star

            # precomputing for some of the sketches
            scores = None
            DPP = None
            if sketch_type == "ridgelev":
                scores = ridgelev_precompute(np.matmul(Dhalf_A, Dhalf_A.T), lm1)
            elif sketch_type == "surrogate_p1":
                DPP = surrogate_p1_precompute(np.matmul(Dhalf_A, Dhalf_A.T), lm1)
            elif sketch_type == "surrogate_p3":
                DPP = surrogate_p3_precompute(np.matmul(Dhalf_A, Dhalf_A.T), lm1)

            update_dir = np.zeros((d,1))
            for worker in range(num_workers):
                if sketch_type == "gaus":
                    SDA, _ = gaussian_sketch(Dhalf_A, no_targets, m)
                elif sketch_type == "unif":
                    SDA, _ = uniform_sampling(Dhalf_A, no_targets, m)
                elif sketch_type == "ridgelev":
                    SDA, _ = ridgelev_sketch(Dhalf_A, no_targets, m, scores)
                elif sketch_type == "rademacher":
                    SDA, _ = rademacher_sketch(Dhalf_A, no_targets, m)
                elif sketch_type == "surrogate_p1":
                    SDA, _ = surrogate_p1_sketch(Dhalf_A, no_targets, m, DPP, effective_dim)
                elif sketch_type == "surrogate_p2":
                    SDA, _ = surrogate_p2_sketch(Dhalf_A, no_targets, m, effective_dim, t)
                elif sketch_type == "surrogate_p3":
                    SDA, _ = surrogate_p3_sketch(Dhalf_A, no_targets, m, DPP, effective_dim)
                else:
                    raise ValueError("unknown sketch_type: {!r}".format(sketch_type))

                hessian_est = np.matmul(SDA.T, SDA) + sketch_reg*np.identity(d)
                update_dir = update_dir + np.linalg.solve(hessian_est, gradient)

            update_dir = update_dir / num_workers

            if lm2 == -999:
                # rescale to compensate for the reduced regularization lm2_star
                update_dir = update_dir * lm2_star / lm1

            # determine step size via line search
            mu = backtracking_linesearch(tau,c,x,f_x,-update_dir,gradient,specific_cost_fn,a0)
            x = x + mu * (-update_dir)

        elif mode == 3: # gradient descent with full gradient
            update_dir = gradient.copy()
            # determine step size via line search
            mu = backtracking_linesearch(tau,c,x,f_x,-update_dir,gradient,
                                         specific_cost_fn,a0)

            if momentum:
                temp_hold = x.copy()
                x = x + mu * (-update_dir) + beta*(x-x_prev)
                x_prev = temp_hold
            else:
                x = x + mu * (-update_dir)

        x_list[iter_no] = x.copy()
        # record cost and wall-clock time after this iteration
        costs[iter_no+1] = specific_cost_fn(x)
        times[iter_no+1] = time.time()

    return x_list, costs, times


