import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import torch
import random
import datetime
import pickle
import math

# All tensors in this script are created on / moved to the CPU.
device = torch.device('cpu')

"""
Parameters
"""
num_test = 20 # number of independent repetitions of the whole experiment

# Timestamp taken once at start-up; it names the output files of this run.
dt = datetime.datetime.now()
def align2char(s):
    """Return ``s`` as a string, left-padded with one ``0`` if it is a single character.

    Always returns ``str`` (the previous version returned the unconverted
    argument for inputs longer than one character).
    """
    text = str(s)
    if len(text) == 1:
        return f"0{text}"
    return text
# Output stem "exp_MMDDhhmm" built from the start-up timestamp.
save_file_name = f"exp_{align2char(dt.month)}{align2char(dt.day)}{align2char(dt.hour)}{align2char(dt.minute)}"
print(f"The result of experiment will be saved in {save_file_name}.pdf")

### Regularization
# lambda1 enters the shrinkage operators S1_*/S2_* below; lambda2 scales the
# drift term of both Langevin samplers and the dual update in MSEupdate.
lambda1 = 1e-2
lambda2 = 1e-5
# Default smoothing constant of the squared loss (see MSEloss / MSEgrad / MSEupdate).
gamma = 0.5

### PDA
T0_PDA = 640 # number of outer iterations
T1_PDA = 10 # steps of Langevin dynamics in each loop
T2_PDA = 50 # minibatch size
eta_PDA = 1e-5 # Langevin step size

### Particle-SDCA
T0_PSDCA = 640 # number of outer iterations
T1_PSDCA = 10 # steps of Langevin dynamics in each loop
T2_PSDCA = 1000 # tilde n: max. number of dual coordinates updated per outer iteration (sliced from a shuffled index list)
eta_PSDCA = 1e-4 # Langevin step size

### Parameters of Ridge
# The three lists are index-aligned: value, legend label, line colour.
ridge_lambdas = [0.1, 1, 10]
ridge_lambdas_str = ["0.1", "1", "10"]
ridge_color = ["darkseagreen", "limegreen", "lime"]

### Parameters of Nadaraya-Watson
# Candidate bandwidths and their legend labels (index-aligned).
NW_bandwidth = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
NW_bandwidth_str = ["0.5", "0.6", "0.7", "0.8", "0.9", "1.0", "1.1", "1.2", "1.3", "1.4", "1.5"]
# Only three colours: the plotting code draws at most three bandwidths.
NW_color = ["aqua", "darkturquoise",  "cadetblue"]
def NW_kernel(u):
    """Unnormalised quadratic kernel: ``1 - u**2`` on ``|u| <= 1``, ``0`` elsewhere."""
    inside_support = abs(u) <= 1
    return 1 - u**2 if inside_support else 0

### Data
n = 100 # number of training data
nt = 100 # number of test data
basis_num = 300 # number of basis using for generating data
train_noise_scaler = 0.1 # std of the Gaussian noise added to training targets

### Galerkin Approximation
# Only the first N of the basis_num coefficients are fed to the network.
N = 150

# about teacher-student setting
Mt = 1 # number of neurons in middle layer of teacher network
target_func = torch.sign # activation applied by the teacher

### Model
layer_size = [N,1] # [input width, output width] of the single layer of each particle
M = 200 # number of particles
# Activation names understood by model / model_grad: "tanh" or "relu"/"ReLU"/"RELU".
nonlinearity_1 = "tanh" 
nonlinearity_2 = "tanh" 
nonlinearity_scaler = 1.0 # multiplicative factor applied to both activations
# bias = True

# Conditions about the input
beta = 3 # input is Gaussian process whose covariance operator is a kernel function of H_{K^beta}.
input_scaler = 1.0
norm_bound = input_scaler * 100 # samples whose weighted norm reaches this bound are rejected

# eigenvalues
# mu[k] = 4 / ((2k+1)^2 * pi), k = 0..basis_num-1
mu = torch.tensor(
    [ 4 / (((2*k+1)**2) * np.pi) for k in range(basis_num)]
)
# eigenvalues^{-1}
mu_inv = torch.tensor(
    [ (((2*k+1)**2) * np.pi) / 4 for k in range(basis_num)]
)

# Operator used in sampling
# Per-coordinate shrinkage applied to params1 after each Langevin step
# (first N coordinates only), and the scalar shrinkage applied to params2.
# Both are fixed here from the module-level step sizes and lambda1.
S1_PDA = torch.tensor(
    [ 1 / (1 + eta_PDA * lambda1 * mu_inv[k]) for k in range(N)]
)
S2_PDA = 1/(1 + eta_PDA * lambda1)

S1_PSDCA = torch.tensor(
    [ 1 / (1 + eta_PSDCA * lambda1 * mu_inv[k]) for k in range(N)]
)
S2_PSDCA = 1/(1 + eta_PSDCA * lambda1)

"""
Model
"""
### Neural Network
def model(X, lambda1, lambda2, params1, params2 = None, innervalues = False, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler):
    """Evaluate every particle on a batch of inputs.

    Each layer computes
    ``act1(<params1[i], input>) * act2(params2[i])`` per particle, where both
    activations are multiplied by ``nonlinearity_scaler``.

    Args:
        X: 2-D tensor ``(batch, in_features)``.
        lambda1, lambda2: accepted for signature compatibility; not used here.
        params1: list of tensors ``(particles, out_features, in_features)``.
        params2: list of tensors ``(particles, out_features)``; required.
        innervalues: if true, also return the list of inputs fed to each layer.
        nonlinearity_1, nonlinearity_2: ``"tanh"`` or ``"relu"``/``"ReLU"``/``"RELU"``.
        nonlinearity_scaler: factor applied to both activations.

    Returns:
        Tensor ``(particles, batch)`` holding output unit 0 of every particle,
        plus the list of layer inputs when ``innervalues`` is true.

    Raises:
        NotImplementedError: if ``params2`` is ``None``.
        ValueError: if an activation name is not supported.
    """
    def _base_activation(name):
        # Map a configuration string to the underlying torch function.
        if name == "tanh":
            return torch.tanh
        if name in ("relu", "ReLU", "RELU"):
            return torch.relu
        raise ValueError(f"Unsupported nonlinearity: {name!r}")

    if params2 is None:
        raise NotImplementedError("model requires params2")
    base_1 = _base_activation(nonlinearity_1)
    base_2 = _base_activation(nonlinearity_2)

    innervalues_ = []
    # Replicate the batch once per particle: (particles, batch, in_features).
    forward = X.repeat((params1[0].shape[0], 1, 1))

    for i in range(len(params1)):
        innervalues_.append(forward)
        forward = torch.einsum('mbi,mji->mbj', forward, params1[i])
        forward = base_1(forward) * nonlinearity_scaler
        # (particles, out) -> (particles, 1, out); broadcasts over the batch axis.
        forward = forward * (base_2(params2[i]) * nonlinearity_scaler).unsqueeze(1)

    if innervalues:
        return forward[:,:,0], innervalues_
    return forward[:,:,0]

def model_grad(X, lambda1, lambda2, params1, params2=None, innervalues=None, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler):
    """Per-sample gradients of the single-layer particle model w.r.t. its parameters.

    Args:
        X: 2-D tensor ``(batch, in_features)``.
        lambda1, lambda2: forwarded to ``model``; not used otherwise.
        params1: list with one tensor ``(particles, out_features, in_features)``.
        params2: list with one tensor ``(particles, out_features)``; required.
        innervalues: optional list of layer inputs as returned by
            ``model(..., innervalues=True)``; recomputed when ``None``.
        nonlinearity_1, nonlinearity_2: ``"tanh"`` or ``"relu"``/``"ReLU"``/``"RELU"``.
        nonlinearity_scaler: factor applied to both activations.

    Returns:
        ``(grad1, grad2)``: lists with, per layer, a tensor
        ``(particles, batch, out_features, in_features)`` and a tensor
        ``(particles, batch, out_features)``.

    Raises:
        NotImplementedError: if ``params2`` is ``None`` or more than one layer is given.
        ValueError: if an activation name is not supported.
    """
    def _activation_pair(name):
        # Return (activation, derivative), both scaled by nonlinearity_scaler.
        if name == "tanh":
            return (lambda x: torch.tanh(x) * nonlinearity_scaler,
                    lambda x: (1-torch.tanh(x)**2) * nonlinearity_scaler)
        if name in ("relu", "ReLU", "RELU"):
            # The derivative of ReLU is the indicator of x > 0
            # (the previous x * (x > 0) was ReLU itself).
            return (lambda x: torch.relu(x) * nonlinearity_scaler,
                    lambda x: (x > 0.0).to(x.dtype) * nonlinearity_scaler)
        raise ValueError(f"Unsupported nonlinearity: {name!r}")

    if params2 is None:
        raise NotImplementedError
    if len(params1) > 1:
        # Back-propagation through earlier layers is not implemented.
        raise NotImplementedError

    activation_function_1, activation_function_1_grad = _activation_pair(nonlinearity_1)
    activation_function_2, activation_function_2_grad = _activation_pair(nonlinearity_2)

    if innervalues is None:
        _, innervalues_ = model(X, lambda1, lambda2, params1, params2 = params2, innervalues = True, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)
    else:
        # Reuse the caller-supplied layer inputs (previously left unbound).
        innervalues_ = innervalues

    grad1 = []
    grad2 = []
    for i in range(len(params1)):
        inner_prod = torch.einsum('mij,mbj->mbi',params1[i],innervalues_[i])
        # (particles, out) -> (particles, 1, out); broadcasts over the batch axis.
        gate = params2[i].unsqueeze(1)
        sigma_i_grad1 = activation_function_1_grad(inner_prod) * activation_function_2(gate)
        sigma_i_grad2 = activation_function_1(inner_prod) * activation_function_2_grad(gate)

        grad1.append(torch.einsum('mbi,mbj->mbij',sigma_i_grad1,innervalues_[i]))
        grad2.append(sigma_i_grad2)
    return grad1, grad2

### Loss functions and corresponding functions for conjugate, initialization, update
def MSEloss(Y,Z, gamma=0.5):
    """Squared loss scaled by ``1 / (2 * gamma)``, evaluated element-wise."""
    residual = Y - Z
    return residual**2 / (2*gamma)

def MSEupdate(integral, n,y_i, g_i_t,max_iter_update = 0, gamma=0.5, eps = 1e-18):
    """Closed-form new value of one dual coordinate for the squared loss.

    ``integral`` is the current prediction for sample i, ``y_i`` its target and
    ``g_i_t`` the current dual value. Uses the module-level ``lambda2``.
    ``max_iter_update`` and ``eps`` are accepted but not used.
    """
    scale = n*lambda2
    numerator = integral - y_i + g_i_t/scale
    denominator = gamma + 1/scale
    return numerator/denominator

def MSEgrad(Y,Z, gamma=0.5):
    """Derivative of ``MSEloss`` with respect to its first argument: ``(Y - Z) / gamma``."""
    residual = Y - Z
    return residual / gamma

# Loss plug-in points used by PDA and PSDCA: the loss itself, the dual
# coordinate update, and the derivative of the loss w.r.t. the prediction.
loss_func = MSEloss
update = MSEupdate
lossgrad = MSEgrad

"""
PDA Algorithm
"""
### Sampling
def LMC_PDA(g, X, A, t,  lambda1, lambda2, params1, params2 = None, eta=eta_PDA, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler):# Langevin Monte Carlo?
    """One Langevin step for the PDA particles.

    The drift is the model gradient weighted by ``A`` and scaled by
    ``2 / (lambda2 * (t + 2) * (t + 3))``; Gaussian noise of scale
    ``sqrt(2 * eta)`` is added, then the parameters are shrunk with the
    module-level ``S1_PDA`` / ``S2_PDA``. ``g`` is not used. The lists
    ``params1`` and ``params2`` are updated in place and also returned.
    """
    weight_grads, gate_grads = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)
    drift_denominator = lambda2*(t+2)*(t+3)
    noise_scale = np.sqrt(2*eta)
    for layer, (weight_grad, gate_grad) in enumerate(zip(weight_grads, gate_grads)):
        # Noise is drawn for params1 first, then for params2 (RNG order matters).
        weight_drift = 2 * torch.einsum('mbij,b->mij', weight_grad, A) / drift_denominator
        weight_noise = torch.FloatTensor(*params1[layer].shape).normal_().to(device)
        params1[layer] = params1[layer] - eta * weight_drift + noise_scale * weight_noise

        gate_drift = 2 * torch.einsum('mbi,b->mi', gate_grad, A) / drift_denominator
        gate_noise = torch.FloatTensor(*params2[layer].shape).normal_().to(device)
        params2[layer] = params2[layer] - eta * gate_drift + noise_scale * gate_noise

        # Shrinkage: per-coordinate along the last axis for params1, scalar for params2.
        params1[layer] = torch.einsum('mbi,i->mbi', params1[layer], S1_PDA)
        params2[layer] = params2[layer] * S2_PDA
    return params1, params2

def sample_PDA(g, X, A, t, M, layer_size, lambda1, lambda2, T1, method=LMC_PDA, params1=None, params2=None, eta=eta_PDA, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler):
    """Initialise the particles if needed, then run ``T1`` sampler steps.

    When ``params1`` is ``None`` both parameter lists are drawn afresh (any
    supplied ``params2`` is replaced): ``params1`` from a zero-mean Gaussian
    with per-coordinate std ``sqrt(mu[:N])``, ``params2`` from a standard
    Gaussian. ``method`` is then called ``T1`` times with the sampler
    arguments.

    Returns:
        ``(params1, params2)`` after the last step.
    """
    if params1 is None:
        params1 = []
        for i in range(len(layer_size)-1):
            params1_mean = torch.zeros(M,layer_size[i+1],layer_size[i])
            params1_std  = (torch.sqrt(mu[:N])).repeat(M,layer_size[i+1],1)
            params1.append(torch.normal(params1_mean, params1_std))

        # Second-layer parameters start from N(0, 1).
        params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=1)    for i in range(len(layer_size)-1)]

    for _ in range(T1):
        params1, params2  = method(g, X, A, t,lambda1, lambda2, params1, params2 =params2, eta=eta, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)
    return params1, params2

### Main
def PDA(X,Y,Xt,Yt,M,lambda1, lambda2, layer_size,record_PDA, T0 = T0_PDA, T1 = T1_PDA, T2 = T2_PDA, eta = eta_PDA, gamma= gamma, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler):
    """Run the PDA training loop and record the loss curves.

    Each outer iteration (a) advances the particles with ``T1`` Langevin
    steps, (b) writes train loss, test loss and the cumulative gradient count
    into rows 0, 1, 2 of ``record_PDA`` at column ``T``, and (c) adds
    ``(T + 1) * loss_derivative / T2`` to ``A`` on a random minibatch of
    ``T2`` training indices.

    Args:
        X, Y: training inputs and targets.
        Xt, Yt: test inputs and targets.
        M: number of particles.
        lambda1, lambda2: regularisation constants forwarded to the sampler.
        layer_size: layer widths used when the particles are initialised.
        record_PDA: 2-D tensor written in place (needs >= 3 rows, >= T0 columns).
        T0, T1, T2: outer iterations, Langevin steps per iteration, minibatch size.
        eta: Langevin step size.
        gamma: loss smoothing constant, passed to both the loss and its derivative.

    Returns:
        ``(params1, params2)`` of the final particles.
    """
    num_data = X.shape[0]
    num_data_test = Xt.shape[0]
    num_grad = 0
    params1 = None
    params2 = None
    activation_kwargs = dict(nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)
    A = torch.zeros(num_data)
    # Uniform particle weights; loop-invariant.
    rs = torch.ones(M) / M

    def _ensemble_prediction(inputs, weights1, weights2):
        # Weighted average of the particle outputs for every sample.
        return torch.einsum('mb,m->b', model(inputs, lambda1, lambda2, weights1, params2=weights2, **activation_kwargs), rs)

    for T in range(T0):
        # The PDA sampler does not use its first argument, so pass None instead
        # of depending on a module-level variable `g` that is not a parameter here.
        params1, params2 = sample_PDA(None, X, A, T, M, layer_size, lambda1, lambda2, T1, params1=params1, params2=params2, eta=eta, **activation_kwargs)
        num_grad += T1 * num_data

        train_pred = _ensemble_prediction(X, params1, params2)
        test_pred = _ensemble_prediction(Xt, params1, params2)
        train_loss = sum(loss_func(train_pred, Y, gamma = gamma))/num_data
        test_loss = sum(loss_func(test_pred, Yt, gamma = gamma))/num_data_test

        record_PDA[0,T] = train_loss
        record_PDA[1,T] = test_loss
        record_PDA[2,T] = num_grad
        if T % 30 == 0:
            print("Step: %d, Gradient evaluation: %d, Training loss: %f, Test loss: %f" % (T,num_grad,train_loss,test_loss))

        indices = list(range(num_data))
        random.shuffle(indices)
        target_coordinate_list = indices[:T2]
        # The particles are unchanged since train_pred was computed, so reuse it.
        diff = lossgrad(train_pred, Y, gamma = gamma)
        A[target_coordinate_list] = A[target_coordinate_list] + (T+1) * diff[target_coordinate_list] / T2
    return params1, params2

"""
PSDCA Algorithm
"""
### Sampling
def LMC_PSDCA(g, X, lambda1, lambda2, params1, params2, eta=eta_PSDCA, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler):
    """One Langevin step for the Particle-SDCA particles.

    The drift is the model gradient weighted by the dual vector ``g`` and
    divided by ``num_data * lambda2``; Gaussian noise of scale
    ``sqrt(2 * eta)`` is added, then the parameters are shrunk with the
    module-level ``S1_PSDCA`` / ``S2_PSDCA``. The lists ``params1`` and
    ``params2`` are updated in place and also returned.
    """
    num_data = X.shape[0]
    weight_grads, gate_grads = model_grad(X, lambda1, lambda2, params1, params2=params2, innervalues=None, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)
    drift_denominator = num_data*lambda2
    noise_scale = np.sqrt(2*eta)
    for layer, (weight_grad, gate_grad) in enumerate(zip(weight_grads, gate_grads)):
        # Noise is drawn for params1 first, then for params2 (RNG order matters).
        weight_drift = torch.einsum('mbij,b->mij', weight_grad, g) / drift_denominator
        weight_noise = torch.FloatTensor(*params1[layer].shape).normal_().to(device)
        params1[layer] = params1[layer] - eta * weight_drift + noise_scale * weight_noise

        gate_drift = torch.einsum('mbi,b->mi', gate_grad, g) / drift_denominator
        gate_noise = torch.FloatTensor(*params2[layer].shape).normal_().to(device)
        params2[layer] = params2[layer] - eta * gate_drift + noise_scale * gate_noise

        # Shrinkage: per-coordinate along the last axis for params1, scalar for params2.
        params1[layer] = torch.einsum('mbi,i->mbi', params1[layer], S1_PSDCA)
        params2[layer] = params2[layer] * S2_PSDCA

    return params1, params2

def sample_PSDCA(g, X, M, layer_size, lambda1, lambda2, T1, method=LMC_PSDCA, params1=None, params2=None, eta=eta_PSDCA, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2,  nonlinearity_scaler=nonlinearity_scaler):
    """Initialise the particles if needed, then run ``T1`` sampler steps.

    When ``params1`` is ``None`` both parameter lists are drawn afresh (any
    supplied ``params2`` is replaced): ``params1`` from a zero-mean Gaussian
    with per-coordinate std ``sqrt(mu[:N])``, ``params2`` from a standard
    Gaussian. ``method`` is then called ``T1`` times; ``T1 == 0`` performs
    the initialisation only.

    Returns:
        ``(params1, params2)`` after the last step.
    """
    if params1 is None:
        params1 = []
        for i in range(len(layer_size)-1):
            params1_mean = torch.zeros(M,layer_size[i+1],layer_size[i])
            params1_std  = (torch.sqrt(mu[:N])).repeat(M,layer_size[i+1],1)
            params1.append(torch.normal(params1_mean, params1_std))

        # Second-layer parameters start from N(0, 1).
        params2 = [torch.FloatTensor(M,layer_size[i+1]).normal_(mean=0,std=1)    for i in range(len(layer_size)-1)]

    for _ in range(T1):
        params1, params2  = method(g, X, lambda1, lambda2, params1, params2 =params2, eta=eta, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)
    return params1, params2

### Main
def PSDCA(X,Y,Xt,Yt,g,M,lambda1, lambda2, layer_size,record_PSDCA, T0 = T0_PSDCA, T1 = T1_PSDCA, T2 = T2_PSDCA, eta = eta_PSDCA, gamma= gamma, nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2,  nonlinearity_scaler=nonlinearity_scaler):
    """Run the Particle-SDCA training loop and record the loss curves.

    Each outer iteration (a) samples the particles (initialisation only at
    ``T == 0``, ``T1`` Langevin steps afterwards), (b) writes train loss, test
    loss and the cumulative gradient count into rows 0, 1, 2 of
    ``record_PSDCA`` at column ``T``, and (c) sweeps over up to ``T2``
    shuffled training indices, updating the dual coordinate ``g[i]`` in place
    and re-weighting the particles after every coordinate.

    Args:
        X, Y: training inputs and targets.
        Xt, Yt: test inputs and targets.
        g: 1-D tensor of dual variables; modified in place.
        M: number of particles.
        lambda1, lambda2: regularisation constants.
        layer_size: layer widths used when the particles are initialised.
        record_PSDCA: 2-D tensor written in place (needs >= 3 rows, >= T0 columns).
        T0, T1, T2: outer iterations, Langevin steps per iteration, coordinates per sweep.
        eta: Langevin step size.
        gamma: loss smoothing constant, passed to both the loss and the dual update.

    Returns:
        ``(params1, params2)`` of the final particles.
    """
    num_data = X.shape[0]
    num_data_test = Xt.shape[0]
    # Start one batch "behind" so that the first record (no Langevin steps yet) reports 0.
    num_grad = -T1*num_data
    params1 = None
    params2 = None
    activation_kwargs = dict(nonlinearity_1=nonlinearity_1, nonlinearity_2=nonlinearity_2, nonlinearity_scaler=nonlinearity_scaler)

    for T in range(T0):
        # The first iteration only draws the initial particles.
        langevin_steps = 0 if T == 0 else T1
        params1, params2 = sample_PSDCA(g, X, M, layer_size, lambda1, lambda2, langevin_steps, params1=params1, params2=params2, eta=eta, **activation_kwargs)
        num_grad += T1 * num_data
        # Particle weights are reset to uniform at the start of every outer iteration.
        rs = torch.ones(M) / M
        train_pred = torch.einsum('mb,m->b', model(X, lambda1, lambda2, params1, params2=params2, **activation_kwargs), rs)
        test_pred = torch.einsum('mb,m->b', model(Xt, lambda1, lambda2, params1, params2=params2, **activation_kwargs), rs)
        train_loss = sum(loss_func(train_pred, Y, gamma = gamma))/num_data
        test_loss = sum(loss_func(test_pred, Yt, gamma = gamma))/num_data_test

        record_PSDCA[0,T] = train_loss
        record_PSDCA[1,T] = test_loss
        record_PSDCA[2,T] = num_grad
        if T % 3 == 0:
            print("Step: %d, Gradient evaluation: %d, Training loss: %f, Test loss: %f" % (T,num_grad,train_loss,test_loss))

        indices = list(range(num_data))
        random.shuffle(indices)
        for i_t in indices[:T2]:
            # Output of every particle on sample i_t; the particles do not change
            # inside this sweep, so one forward pass serves both uses below.
            particle_out = model(X[i_t:i_t+1,:], lambda1, lambda2, params1, params2=params2, **activation_kwargs)[:,0]
            integral = sum(particle_out * rs) / sum(rs)
            delg = - g[i_t]
            g[i_t] = update(integral, num_data, Y[i_t], g[i_t], gamma = gamma)
            delg += g[i_t]
            # Exponential re-weighting by the change of the dual coordinate, then renormalise.
            rs = rs * torch.exp(- particle_out*delg/(num_data*lambda2))
            rs = rs/sum(rs)
    return params1, params2

"""
Record & Compare
"""
# Per-trial results: one record tensor per trial for PSDCA/PDA, and one list of
# test losses per ridge lambda / Nadaraya-Watson bandwidth.
psdca_res = []
pda_res = []
ridge_res = [[] for _ in range(len(ridge_lambdas))]
NW_res = [[] for _ in range(len(NW_bandwidth))]

# Each trial draws a new teacher, new data, and runs all four estimators.
for i in range(num_test):
    print(i)

    """
    Generating the training data and test data (teacher-student setup)
    """
    # Teacher parameters: Mt neurons, each with basis_num input weights and one
    # output weight, all drawn from N(0, 5^2).
    w_list = []
    a_list = []
    for _ in range(Mt):
        w_list.append([np.random.normal(loc=0.0, scale=5.0, size=None) for j in range(basis_num)])
        a_list.append(np.random.normal(loc=0.0, scale=5.0, size=None))

    ### Function that calculates the target value
    # alpha is a coefficient of inputs x for the basis of \mathcal H
    # alpha and w are generated above. 
    # Returns the average over teacher neurons of
    # target_func(a_k) * target_func(<w_k, alpha>).
    def target(alpha):
        ret = 0
        for k in range(Mt):
            ret1 = target_func(torch.tensor(a_list[k]))
            ret2 = target_func(torch.tensor(sum([ w_list[k][j]*alpha[j] for j in range(basis_num) ])))
            ret += ret1 * ret2
        return ret / Mt
        
    ### Generating data
    # Independent Gaussian coefficients with std input_scaler * mu[j]^(beta/2).
    def generate_alpha():
        return [np.random.normal(loc=0.0, scale=input_scaler * (mu[j]**(beta/2)), size=None) for j in range(basis_num)]

    train_alpha = []
    train_y = []
    print("Training Data Generating...")
    for k in tqdm(range(n)):
        # Rejection sampling: redraw until the weighted squared norm is below norm_bound^2.
        while True:
            alpha = generate_alpha()
            # NOTE(review): `mu_inv[j] ** beta-1` parses as `(mu_inv[j] ** beta) - 1`.
            # If the exponent `beta - 1` was intended, parentheses are missing — confirm.
            norm_square = sum([ (alpha[j]**2) * (mu_inv[j] ** beta-1) for j in range(basis_num)])
            if norm_square < norm_bound ** 2:
                break
        train_alpha.append(alpha)
        # Training targets carry additive Gaussian noise of std train_noise_scaler.
        train_y.append(target(alpha) + np.random.normal(loc=0.0, scale=train_noise_scaler, size=None))

    
    test_alpha = []
    test_y = []
    print("Test Data Generating...")
    for k in tqdm(range(nt)):
        # Same rejection rule as for the training inputs (same NOTE(review) applies).
        while True:
            alpha = generate_alpha()
            norm_square = sum([ (alpha[j]**2) * (mu_inv[j] ** beta-1) for j in range(basis_num)])
            if norm_square < norm_bound ** 2:
                break
        test_alpha.append(alpha)
        # Test targets are noise-free.
        test_y.append(target(alpha))
    
    # Network inputs are the first N coefficients only; the targets above were
    # computed from all basis_num coefficients.
    X = []
    for j in range(len(train_alpha)):
        X.append(train_alpha[j][:N])
    Xt = []
    for j in range(len(test_alpha)):
        Xt.append(test_alpha[j][:N])

    X = torch.tensor(X)
    Xt = torch.tensor(Xt)
    Y = torch.tensor(train_y)
    Yt = torch.tensor(test_y)

    # Dual variables, one per training sample; PSDCA updates this tensor in place.
    # PDA also reads this module-level name when calling its sampler.
    g = torch.zeros_like(Y)

    """
    Learning
    """
    # Record layout: row 0 train loss, row 1 test loss, row 2 cumulative gradient
    # evaluations; the remaining rows and the last column stay zero.
    print("PSDCA", i)
    record_PSDCA = torch.zeros(6,T0_PSDCA+1)
    params1, params2 = PSDCA(X,Y,Xt,Yt,g,M, lambda1,lambda2, layer_size, record_PSDCA, T0 = T0_PSDCA, T1=T1_PSDCA, T2= T2_PSDCA , eta=eta_PSDCA, nonlinearity_scaler=nonlinearity_scaler, gamma=0.5)

    print("PDA", i)
    record_PDA = torch.zeros(6,T0_PDA+1)
    params1, params2 = PDA(X,Y,Xt,Yt,M, lambda1, lambda2, layer_size, record_PDA, T0 = T0_PDA, T1=T1_PDA, T2=T2_PDA,eta=eta_PDA, nonlinearity_scaler=nonlinearity_scaler, gamma=0.5)

    # Save the result of PDA and PSDCA
    # NOTE(review): nothing here creates the "record" directory — presumably it
    # must exist before the script is started; confirm.
    torch.save(record_PSDCA ,f"record/PSDCA_{i}.pth")
    torch.save(record_PDA ,f"record/PDA_{i}.pth")

    psdca_res = psdca_res +[record_PSDCA]
    pda_res =pda_res + [record_PDA]

    ### Calculate the Ridge-regression estimator
    # Closed form: solve (X^T X + lambda I) w = X^T Y, then score on the test set.
    for j, ridge_lambda in enumerate(ridge_lambdas):
        ridge_param = torch.linalg.solve(torch.transpose(X, 0, 1) @ X + ridge_lambda * torch.eye(X.shape[1]), torch.transpose(X, 0, 1) @ Y)
        Yt_predict = Xt @ ridge_param
        ridge_res[j].append(sum(loss_func(Yt_predict, Yt))/nt)

    ### Calculate the Nadaraya-Watson estimator
    for idx, h in enumerate(NW_bandwidth):
        Yt_predict = []
        for l in range(nt):
            # Distances use all basis_num coefficients (not just the first N) and are halved.
            # NOTE(review): the reason for the factor 1/2 is not visible here — confirm.
            norms = [
                np.sqrt(sum([(train_alpha[k][j] - test_alpha[l][j])**2 for j in range(basis_num)]))/2 for k in range(n)
            ]
            # Kernel-weighted average of the training targets. If no training point
            # falls inside the kernel support the denominator is zero.
            Yt_predict.append(
                sum([train_y[k] * NW_kernel(norms[k] / h) for k in range(n)]) / sum([NW_kernel(norms[k] / h) for k in range(n)])
            )
        Yt_predict = torch.tensor(Yt_predict)
        
        NW_res[idx].append(sum(loss_func(Yt_predict, Yt))/nt)

# Save the result of Ridge regression estimator
with open(f'record/{save_file_name}_ridge.pickle', mode='wb') as f:
    pickle.dump(ridge_res,f)

# Save the result of Nadaraya-Watson estimator
with open(f'record/{save_file_name}_NW.pickle', mode='wb') as f:
    pickle.dump(NW_res,f)

### Drawing the result of PDA and PSDCA
# Stack the per-trial records into (trials, rows, columns) tensors. Each list is
# stacked over its own length (pda_ previously borrowed len(psdca_res)).
psdca_ = torch.stack(psdca_res)
pda_ = torch.stack(pda_res)

torch.save(psdca_, f"record/{save_file_name}_psdca.pth")
torch.save(pda_, f"record/{save_file_name}_pda.pth")

def _mean_and_spread(curves):
    """Return (mean over trials, centred curves, error-bar half-width).

    The half-width is sqrt(sum(centred^2) / num_test^2), i.e. the
    root-mean-square deviation divided by sqrt(num_test).
    """
    mean_curve = sum(curves)/num_test
    centred = curves - mean_curve
    half_width = (sum(centred**2)/(num_test**2))**0.5
    return mean_curve, centred, half_width

# Row 1 of every record holds the test loss.
psdca = psdca_[:,1,:]
psdca_mean, psdca_reg, psdca_var = _mean_and_spread(psdca)

pda = pda_[:,1,:]
pda_mean, pda_reg, pda_var = _mean_and_spread(pda)

# Row 2 holds the cumulative gradient evaluations; take them from the first trial.
psdca_iters = psdca_res[0][2]
pda_iters = pda_res[0][2]

# The last column of a record is never written, so [-2] is the final recorded count.
itermax = max(int(pda_iters[-2]), int(psdca_iters[-2]))

fig = plt.figure()
FONT_SIZE = 14
plt.rc('font',size=FONT_SIZE)

ax1 = fig.add_subplot(1,1,1)

ax1.set_xlabel( "Gradient evaluations", fontsize=16)
ax1.set_ylabel( "Test loss", fontsize=16 )

ax1.ticklabel_format(style='sci',axis='y',scilimits=(0,0))
ax1.ticklabel_format(style='sci',axis='x',scilimits=(0,0))

plt.setp(ax1.get_xticklabels(), fontsize=14)
plt.setp(ax1.get_yticklabels(), fontsize=14)

# Offset subtracted from the plotted curves (0 = plot raw test loss).
global_opt = 0

# Every 10th recorded point, excluding the last two columns of each record.
pda_centre = pda_mean[:-2:10] - global_opt
psdca_centre = psdca_mean[:-2:10] - global_opt
pda_err = pda_var[:-2:10]
psdca_err = psdca_var[:-2:10]

ax1.errorbar(pda_iters[:-2:10], pda_centre, pda_err,label = "PDA",ecolor= (0.27450980392156865, 0.5098039215686274, 0.7058823529411765,0.3), color = (0.27450980392156865, 0.5098039215686274, 0.7058823529411765,1.0))
ax1.errorbar(psdca_iters[:-2:10], psdca_centre, psdca_err,label = "P-SDCA",ecolor= (0.8627450980392157, 0.0784313725490196, 0.23529411764705882,0.3), color = (0.8627450980392157, 0.0784313725490196, 0.23529411764705882,1.0))

# Baselines are drawn as horizontal dashed lines across the whole x range.
baseline_x = list(range(itermax+1))

### Drawing the result of Ridge estimator
ridge_draw = []
for j, ridge_lambda in enumerate(ridge_lambdas):
    ridge_draw.append(sum(ridge_res[j])/num_test)
    ax1.plot(baseline_x, [ridge_draw[-1]]*(itermax+1),  linestyle = "dashed", label = f"Ridge: λ={ridge_lambdas_str[j]}", color=ridge_color[j])

### Drawing the result of Nadaraya-Watson estimator
# Note that the denominator of the estimator may be zero, which results that the estimator becomes nan. 
# In this case, we don't include that bandwidth h in the experimental results. 
# We select three bandwidths for which the estimator does not become nan and plot them.
NW_draw = []
idx = 0
for j, h in enumerate(NW_bandwidth):
    nw_mean = sum(NW_res[j])/num_test
    if math.isnan(nw_mean):
        continue
    # Only plotted (non-NaN) values take part in the axis limits below: a NaN
    # inside max()/min() can mask finite values because comparisons with NaN are False.
    NW_draw.append(nw_mean)
    ax1.plot(baseline_x, [nw_mean]*(itermax+1),  linestyle = "dashed", label = f"Nadaraya-Watson: h={NW_bandwidth_str[j]}", color=NW_color[idx])
    idx += 1
    if idx >= len(NW_color):
        break

ax1.xaxis.offsetText.set_fontsize(14)
ax1.yaxis.offsetText.set_fontsize(14)

# y-limits span the error bars and every baseline that was drawn
# (empty baseline lists are simply skipped).
ymax = max([
        max(pda_centre + pda_err),
        max(psdca_centre + psdca_err),
        *ridge_draw, *NW_draw
    ])
ymin = min([
        min(pda_centre - pda_err),
        min(psdca_centre - psdca_err),
        *ridge_draw, *NW_draw
    ])
ax1.set_ylim([ymin,ymax])
ax1.set_xlim([0,itermax])

ax1.set_yscale('log')
ax1.legend(loc=u"upper right", prop={'size':12} )

fig.savefig(f"./{save_file_name}.pdf", bbox_inches="tight")