clear all
clc;
close all;

%% Loading Data
% Read the MNIST training files and keep only the samples labelled 0 or 8.
% Each kept image is flattened to a 1-by-d row and scaled to unit 2-norm.
NS = 60000;   % number of samples requested from readMNIST
d  = 400;     % length of one flattened image row
[x,label] = readMNIST('train-images-idx3-ubyte', 'train-labels-idx1-ubyte', NS, 0);
n   = sum(label==0) + sum(label==8);   % number of samples in the two classes
dig = zeros(n,1);
X   = zeros(n, d);

%%%% select two classes %%%%%
j = 0;   % number of rows of X filled so far
for i = 1:NS
    if (label(i) ~= 0) && (label(i) ~= 8)
        continue;   % sample belongs to neither class
    end
    j = j + 1;
    % Recode the digit as label/4 - 1 (gives -1 for 0 and +1 for 8 when the
    % labels are floating point -- TODO confirm the type readMNIST returns).
    dig(j) = label(i)/4 - 1;
    X(j,:) = reshape(x(:,:,i), 1, d);
    X(j,:) = X(j,:) / norm(X(j,:));
end

% Order the samples by recoded label so each class occupies a contiguous
% range of rows; X is permuted with the same index vector.
label_raw = dig;
[label,I] = sort(label_raw, 'ascend');
X = X(I,:);

%% Loading Test Data
% Two-class (0 vs 8) extraction from the MNIST test files. Rows are
% flattened and scaled to unit 2-norm; they are kept in file order (no
% sorting by label is done here).
N_test = 10000;   % number of test samples requested from readMNIST
d  = 400;         % length of one flattened image row
[x_test,label_raw] = readMNIST('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte', N_test, 0);
n_test     = sum(label_raw==0) + sum(label_raw==8);
label_test = zeros(n_test,1);
X_test     = zeros(n_test, d);

%%%% select two classes %%%%%
j = 0;   % number of rows of X_test filled so far
for i = 1:N_test
    if (label_raw(i) ~= 0) && (label_raw(i) ~= 8)
        continue;   % sample belongs to neither class
    end
    j = j + 1;
    % Recode the digit as label/4 - 1 (gives -1 for 0 and +1 for 8 when the
    % labels are floating point -- TODO confirm the type readMNIST returns).
    label_test(j) = label_raw(i)/4 - 1;
    X_test(j,:) = reshape(x_test(:,:,i), 1, d);
    X_test(j,:) = X_test(j,:) / norm(X_test(j,:));
end


%% Global Parameters
L  = 0.25;             % passed to the solvers as L; also used for step sizes
mu = L*10^(-3);        % passed to the solvers as mu
logi_reg = mu/2;       % weight of the norm(w)^2 term in the objective
numProcesses = 16;     % number of nodes in the network
grad_noise_std = 0;    % not referenced anywhere else in this script

%% Create the Network Structure

settingName = 'Circular';
%  settingName = 'Barbell';
% settingName = 'Connected';

[W_DS, Lap] = getNetworkWeights(settingName, numProcesses);

% Gossip matrix W = I - W_DS. With the eigenvalues sorted ascending,
% lambda_max is the largest and lambda_min the second smallest.
W    = eye(numProcesses) - W_DS;
Weig = sort(eig(W));
lambda_max = Weig(end);
lambda_min = Weig(2);


% Distribute the data to the nodes
% Node p gets the contiguous index range of width temp_inc starting at
% (p-1)*temp_inc+1, clipped at n; trailing nodes may receive an empty range.
temp_inc = ceil(n/numProcesses);
data_ix  = cell(numProcesses,1);
for p_ix = 1:numProcesses
    first_ix = (p_ix-1)*temp_inc + 1;
    last_ix  = min(p_ix*temp_inc, n);
    data_ix{p_ix} = first_ix:last_ix;
end

data_full = 1:n;   % index set covering every training sample
%% Problem Definition
% obj_fn(w, ix): logistic loss averaged over the samples indexed by ix,
%                plus logi_reg * norm(w)^2.
% grad_fn_stoch(w, ix): gradient of obj_fn with respect to w on the same
%                index set.
% Both handles capture X, label and logi_reg as they are at this point.
obj_fn = @(w,ix) ( ...
    sum(log(1+exp(-label(ix).*(X(ix,:)*w))))/length(ix) ...
    + logi_reg * norm(w)^2 );
grad_fn_stoch = @(w,ix)( ...
    -X(ix,:)'*(label(ix)./(1+exp(label(ix).*(X(ix,:)*w))))/length(ix) ...
    + 2 * logi_reg * w );

%% Find the solution by SVRG
% The SVRG result is used below as the reference optimum (x_optim, f_optim).
batch_size = 40;
n_svrg     = 50;
eta_svrg   = 1/L;
T          = ceil(n/batch_size);
x_svrg   = SVRG(obj_fn, grad_fn_stoch, n_svrg, zeros(d,1), eta_svrg, data_full, batch_size, T);
obj_svrg = obj_fn(x_svrg, data_full);
fprintf('SVRG last loss value: %.8f \n', obj_svrg)

x_optim = x_svrg;
f_optim = obj_svrg;


%% Initialization
% Random d-by-numProcesses starting matrix, reused by every method below.
x0 = randn(d,numProcesses);

%% Set the number of iterations
numIter   = 20;
IterInner = 100;

%% IDEAL + AGD
% Parameters derived from L, mu and the eigenvalues of W, then one run of
% AccGDSolver on the gossip matrix W.
rho  = L/lambda_max;
eta  = 1/(lambda_max*rho + L);
eta2 = rho + mu/lambda_max;

kap_in  = (L + rho*lambda_max)/mu;
kap_out = (L + rho*lambda_min)/(mu + rho*lambda_max)*lambda_max/lambda_min;

sqrt_in  = sqrt(kap_in);
sqrt_out = sqrt(kap_out);
beta_in  = (sqrt_in - 1)/(sqrt_in + 1);
beta_out = (sqrt_out - 1)/(sqrt_out + 1);

T_agd = IterInner;
n_agd = numIter;
[iter_AGD,err_AGD,AGD_end,test_acc_AGD] = AccGDSolver( ...
    obj_fn,grad_fn_stoch,numProcesses,n_agd,x0,W,T_agd,L,mu,eta,eta2,beta_in,beta_out, ...
    lambda_min,lambda_max,rho,data_ix,x_optim,batch_size,data_full,X_test,label_test,n_test);

% Report the full-data objective at the row-wise mean of AGD_end.
x_agd   = mean(AGD_end,2);
obj_agd = obj_fn(x_agd, data_full);
fprintf('AGD last loss value: %.8f \n', obj_agd)

fprintf("AccGD Finished...\n");

%% MIDEAL + AGD
% Build a second gossip matrix W2 with ACCELGOSSIP (K rounds), recompute the
% solver parameters from its eigenvalues, and run AccGDSolver on W2.
K  = floor(sqrt(lambda_max/lambda_min));
W2 = ACCELGOSSIP(eye(numProcesses), W, K);

% Real parts of the largest and second-smallest sorted eigenvalues of W2.
W2eig = sort(eig(W2));
lambda_max_W2 = real(W2eig(numProcesses));
lambda_min_W2 = real(W2eig(2));
fprintf('Graph condition number of W2: %.4f \n', lambda_max_W2/lambda_min_W2)

rho  = L/lambda_max_W2;
eta  = 1/(lambda_max_W2*rho + L);
eta2 = rho + mu/lambda_max_W2;

kap_in  = (L + rho*lambda_max_W2)/mu;
kap_out = (L + rho*lambda_min_W2)/(mu + rho*lambda_max_W2)*lambda_max_W2/lambda_min_W2;

sqrt_in  = sqrt(kap_in);
sqrt_out = sqrt(kap_out);
beta_in  = (sqrt_in - 1)/(sqrt_in + 1);
beta_out = (sqrt_out - 1)/(sqrt_out + 1);

T_magd = IterInner;
n_magd = numIter;
[iter_MAGD,err_MAGD,MAGD_end,test_acc_MAGD] = AccGDSolver( ...
    obj_fn,grad_fn_stoch,numProcesses,n_magd,x0,W2,T_magd,L,mu,eta,eta2,beta_in,beta_out, ...
    lambda_min_W2,lambda_max_W2,rho,data_ix,x_optim,batch_size,data_full,X_test,label_test,n_test);

% Report the full-data objective at the row-wise mean of MAGD_end.
x_magd   = mean(MAGD_end,2);
obj_magd = obj_fn(x_magd, data_full);
fprintf('MAGD last loss value: %.8f \n', obj_magd)


%% SSDA
% AccGDSolver on the gossip matrix W with rho = 0 and twice as many outer
% iterations as the IDEAL runs.
rho  = 0;
eta  = 1/L;
eta2 = mu/lambda_max;

kap_in  = L/mu;
kap_out = (L/mu)*(lambda_max/lambda_min);

sqrt_in  = sqrt(kap_in);
sqrt_out = sqrt(kap_out);
beta_in  = (sqrt_in - 1)/(sqrt_in + 1);
beta_out = (sqrt_out - 1)/(sqrt_out + 1);

T_ssda = IterInner;
nb_itr = 2*numIter;  %10*numIter

[iter_SSDA,err_SSDA,SSDA_end,test_acc_SSDA] = AccGDSolver( ...
    obj_fn,grad_fn_stoch,numProcesses,nb_itr,x0,W,T_ssda,L,mu,eta,eta2,beta_in,beta_out, ...
    lambda_min,lambda_max,rho,data_ix,x_optim,batch_size,data_full,X_test,label_test,n_test);

% Report the full-data objective at the row-wise mean of SSDA_end.
x_ssda   = mean(SSDA_end,2);
obj_ssda = obj_fn(x_ssda, data_full);
fprintf('SSDA last loss value: %.8f \n', obj_ssda)


%% MSDA
% AccGDSolver on the accelerated-gossip matrix W2 with rho = 0, using the
% eigenvalues of W2 and twice as many outer iterations as the IDEAL runs.

rho=  0;  
eta = 1./L;
eta2 = (mu/lambda_max_W2);
kap_out= (L)./(mu)* (lambda_max_W2)/(lambda_min_W2);
kap_in = ( L)./mu;
beta_in  = (sqrt(kap_in)-1)./(sqrt(kap_in)+1);
beta_out = (sqrt(kap_out)-1)./(sqrt(kap_out)+1);

T_msda = IterInner;
nb_itr = 2*numIter; %6*numIter

[iter_MSDA,err_MSDA,MSDA_end,test_acc_MSDA]=AccGDSolver(obj_fn,grad_fn_stoch,numProcesses,nb_itr,x0,W2,T_msda,L,mu,eta,eta2,beta_in, beta_out,...
    lambda_min_W2,lambda_max_W2,rho,data_ix,x_optim,batch_size,data_full,X_test,label_test,n_test);

% Report the full-data objective at the row-wise mean of MSDA_end.
% (Previously this averaged SSDA_end, so the printed "MSDA" loss was
% actually the SSDA loss.)
x_msda = mean(MSDA_end,2);
obj_msda = obj_fn(x_msda, data_full);
fprintf('MSDA last loss value: %.8f \n', obj_msda)

fprintf("MSDA Finished ... \n");

%% APM_C
% Run APM_C on the weight matrix W_DS. The fourth argument is
% 1 - lambda_min, where lambda_min is the second-smallest eigenvalue of
% W = I - W_DS.
numIter_AMPC = IterInner;
[iter_APMC,err_APMC,APMC_end,test_acc_APMC] = APM_C( ...
    obj_fn,grad_fn_stoch,numProcesses,1 - lambda_min,numIter_AMPC,IterInner,x0,W_DS,L,mu, ...
    data_ix,x_optim,batch_size,data_full,X_test,label_test,n_test);

% Report the full-data objective at the row-wise mean of APMC_end.
x_apmc   = mean(APMC_end,2);
obj_apmc = obj_fn(x_apmc, data_full);
fprintf('APMC last loss value: %.8f \n', obj_apmc)

fprintf("APM_C Finished...\n");


%% EXTRA
% Run EXTRA on W_DS with step size 1/L for numIter*IterInner iterations.
numIter_Extra = numIter*IterInner;
alpha = 1/L;

[iter_EXTRA,err_EXTRA,EXTRA_end,test_acc_EXTRA] = EXTRA( ...
    obj_fn,grad_fn_stoch,numProcesses,numIter_Extra,W_DS,x0,alpha, ...
    data_ix,x_optim,batch_size,data_full,X_test,label_test,n_test);

% Report the full-data objective at the row-wise mean of EXTRA_end.
x_extra   = mean(EXTRA_end,2);
obj_extra = obj_fn(x_extra, data_full);
fprintf('EXTRA last loss value: %.8f \n', obj_extra)
fprintf("EXTRA Finished ...\n");

%% Save the outputs
% Write the reference optimum, the problem and graph constants, and each
% method's outputs to "<settingName>_output.mat".
savefilename = sprintf("%s_output.mat", settingName);
savedVars = { ...
    'f_optim','K','L','mu','lambda_max','lambda_min','numProcesses', ...
    'T_agd','iter_AGD','err_AGD','AGD_end','test_acc_AGD', ...
    'T_magd','iter_MAGD','err_MAGD','MAGD_end','test_acc_MAGD', ...
    'T_ssda','iter_SSDA','err_SSDA','SSDA_end','test_acc_SSDA', ...
    'T_msda','iter_MSDA','err_MSDA','MSDA_end','test_acc_MSDA', ...
    'IterInner','iter_APMC','err_APMC','APMC_end','test_acc_APMC', ...
    'iter_EXTRA','err_EXTRA','EXTRA_end','test_acc_EXTRA'};
save(savefilename, savedVars{:})

%% Print function value 
% Plot log(f(x_k) - f^*) for every method against a simulated time axis and
% export the figure as EPS into ./figures.
close all
figPos = [100, 100, 800, 600];
figure('Position', figPos);
set(gcf,'color','w');
% tau weights the communication term of each time axis below
% (presumably the cost of one communication round relative to one
% computation step -- confirm against the experiment description).
tau= 0.1; 

% Fixed axis limits per supported tau value; set before "hold on" so the
% subsequent plot calls keep them.
if tau == 10
    xlim([0 18000])
    ylim([-20,0])
elseif tau == 1
    xlim([0 4000])
    ylim([-20,0])
elseif tau == 0.1 
%%%% For small tau, 
%%%% increase the iterations of SSDA/MSDA to make the x-axis aligned
    xlim([0 2000])
    ylim([-20,0])
end

hold on

% Each ax_* vector assigns a time stamp to every recorded entry of the
% corresponding iter_* trace.
ax_AGD=(0:length(iter_AGD)-1)*T_agd*(1+tau);
plot(ax_AGD,log(iter_AGD-f_optim),'LineWidth',3);


ax_MAGD=(0:length(iter_MAGD)-1)*T_magd*(1+K*tau);
plot(ax_MAGD,log(iter_MAGD-f_optim),'LineWidth',3);

ax_APMC=(0:length(iter_APMC)-1)*(1+numIter_AMPC*tau);
plot(ax_APMC,log(iter_APMC-f_optim),'LineWidth',3);

ax_SSDA=(0:length(iter_SSDA)-1)*(T_ssda+tau);
plot(ax_SSDA,log(iter_SSDA-f_optim),'LineWidth',3);


ax_MSDA=(0:length(iter_MSDA)-1)*(T_msda+K*tau);
plot(ax_MSDA,log(iter_MSDA - f_optim),'LineWidth',3);

ax_EXTRA=(0:length(iter_EXTRA)-1)*(1+tau);
plot(ax_EXTRA,log(iter_EXTRA-f_optim),'LineWidth',3);



xlabel('Time')
ylabel('$$\log(f(x_k) -f^*)$$','Interpreter','Latex');
legend("IDEAL+AGD","MIDEAL+AGD","APM-C","SSDA+AGD","MSDA+AGD","EXTRA",'Location','SouthWest','FontSize',25) %'NorthEast'

tit=sprintf("%s Graph (n= %1.0f, \\tau = %1.1f, \\kappa_f = %1.0f, \\kappa_W = %1.1f)",settingName, numProcesses,tau, L./mu, lambda_max/lambda_min);
set(gca,'FontSize', 30);
title(tit,'FontSize', 25)

grid on
grid minor


% saveas does not create missing folders; make sure the output folder
% exists so the export cannot fail at the very end of a long run.
if ~exist('./figures', 'dir')
    mkdir('./figures');
end
filename = sprintf("./figures/%sGraphtau=%1.e", settingName, tau);
saveas(gcf, filename,'epsc')


%% Print test accuracy
% Reload the saved results and plot each method's test accuracy against the
% same simulated time axes used for the objective plot; export as EPS into
% ./figures.
loadfilename = sprintf("%s_output.mat", settingName);
load(loadfilename, 'f_optim','K', 'L','mu','lambda_max','lambda_min','numProcesses',...
            'T_agd','iter_AGD','err_AGD','AGD_end','test_acc_AGD', ...
            'T_magd','iter_MAGD','err_MAGD','MAGD_end','test_acc_MAGD',...
            'T_ssda','iter_SSDA','err_SSDA','SSDA_end','test_acc_SSDA',...
            'T_msda','iter_MSDA','err_MSDA','MSDA_end','test_acc_MSDA',...
            'IterInner','iter_APMC','err_APMC','APMC_end','test_acc_APMC', ...
            'iter_EXTRA','err_EXTRA','EXTRA_end','test_acc_EXTRA')

% numIter_AMPC is not stored in the .mat file. The APM_C section defines it
% as IterInner, so rebuild it from the loaded value; this keeps the APM-C
% time axis consistent with the loaded results and lets this section run
% without the solver sections having been executed in the same session.
numIter_AMPC = IterInner;

close all
figPos = [100, 100, 800, 600];
figure('Position', figPos);
set(gcf,'color','w');
% tau weights the communication term of each time axis below
% (presumably the cost of one communication round relative to one
% computation step -- confirm against the experiment description).
tau= 0.1; 

% Fixed axis limits per supported tau value; set before "hold on" so the
% subsequent plot calls keep them.
if tau == 10
    xlim([0 18000])
    ylim([0.6,1.])
elseif tau == 1
    xlim([0 4000])
    ylim([0.6,1.])
elseif tau == 0.1 
%%%% For small tau, 
%%%% increase the iterations of SSDA/MSDA to make the x-axis aligned
    xlim([0 2000])
    ylim([0.6,1.])
end

hold on

% Each ax_* vector assigns a time stamp to every recorded entry; its length
% is taken from the iter_* trace (assumes test_acc_* has the same length --
% confirm against the solver implementations).
ax_AGD=(0:length(iter_AGD)-1)*T_agd*(1+tau);
plot(ax_AGD,test_acc_AGD,'LineWidth',3);


ax_MAGD=(0:length(iter_MAGD)-1)*T_magd*(1+K*tau);
plot(ax_MAGD,test_acc_MAGD,'LineWidth',3);

ax_APMC=(0:length(iter_APMC)-1)*(1+numIter_AMPC*tau);
plot(ax_APMC,test_acc_APMC,'LineWidth',3);

ax_SSDA=(0:length(iter_SSDA)-1)*(T_ssda+tau);
plot(ax_SSDA,test_acc_SSDA,'LineWidth',3);


ax_MSDA=(0:length(iter_MSDA)-1)*(T_msda+K*tau);
plot(ax_MSDA,test_acc_MSDA,'LineWidth',3);

ax_EXTRA=(0:length(iter_EXTRA)-1)*(1+tau);
plot(ax_EXTRA,test_acc_EXTRA,'LineWidth',3);



xlabel('Time')
ylabel('Test accuracy','Interpreter','Latex');
legend("IDEAL+AGD","MIDEAL+AGD","APM-C","SSDA+AGD","MSDA+AGD","EXTRA",'Location','SouthWest','FontSize',25) %'NorthEast'

tit=sprintf("%s Graph (n= %1.0f, \\tau = %1.1f, \\kappa_f = %1.0f, \\kappa_W = %1.1f)",settingName, numProcesses,tau, L./mu, lambda_max/lambda_min);
set(gca,'FontSize', 30);
title(tit,'FontSize', 25)

grid on
grid minor


% saveas does not create missing folders; make sure the output folder
% exists so the export cannot fail at the very end of the script.
if ~exist('./figures', 'dir')
    mkdir('./figures');
end
filename = sprintf("./figures/%s_test_acc_Graphtau=%1.e", settingName, tau);
saveas(gcf, filename,'epsc')
