%% Clearing
% Plain `clear` resets the workspace; `clear all` would also flush cached
% functions/MEX files, which this script does not need.
clear; close all; clc


%% Settings
tf = 300; % final time
% vanishing_damping uses a 3/t coefficient, which is singular at t = 0,
% so integration starts at eps instead of 0.
tspan = [eps tf];
%ic = [1,1,0,0]'; % alternative initial condition
ic = [0.05,0.05,0,0]'; % initial condition, laid out as [X0; V0] (position, then velocity)

% Matrix of the quadratic objective f(x) = x'*H*x; the flows below use
% grad f(x) = 2*H*x.
H = [2*1e-2, 0*1e-3; 0*1e-3, 3*1e-4];
n = size(H,1); % dimension of X; the ODE state [X; V] has length 2*n


%% Solving the ODE numerically
opts = odeset('RelTol',1e-5,'AbsTol',1e-10);
[t1,y1] = ode45(@(t,y) vanishing_damping(t,y,H), tspan, ic, opts);
[t2,y2] = ode45(@(t,y) maximum_damping(t,y,H), tspan, ic, opts);
[t3,y3] = ode45(@(t,y) critical_damping(t,y,H), tspan, ic, opts);
[t4,y4] = ode45(@(t,y) gd(t,y,H), tspan, ic, opts);

%% Plotting Dynamics
% Row-wise objective f(X) = X'*H*X for every time sample of an ode45
% solution; the first n columns of the solution hold the position X(t).
% Works for any dimension n, not just 2.
fval = @(Y) sum((Y(:,1:n)*H) .* Y(:,1:n), 2);

figure
semilogy(t4,fval(y4),'Linewidth',3,'Color',[0.667 0.569 0.157]); hold on;
semilogy(t2,fval(y2),'Linewidth',3,'Color',[0.938 0.566 0.208]);
semilogy(t3,fval(y3),'Linewidth',3,'Color',[0.314 0.459 0.78]);
semilogy(t1,fval(y1),'Linewidth',3,'Color',[0.694 0.259 1.0]);
% Legend order must match the plotting order above (gd, max, critical, vanishing).
legend('Gradient Flow','Nesterov Flow (damping $2\sqrt{\beta}$)','Nesterov Flow (damping $2\sqrt{\mu}$)','Nesterov Flow (damping $3/t$)','FontSize',20,'Interpreter','Latex','location','best')
xlabel('$t$','FontSize',20,'Interpreter','Latex')
ylabel('$f(X(t))$','FontSize',20,'Interpreter','Latex')
saveas(gcf,'nesterov_performance','epsc');


%% Functions
function dydt = vanishing_damping(t,y,H)
    % First-order form of  X'' + (3/t) X' + 2*H*X = 0  (Nesterov flow with
    % damping 3/t, which vanishes as t grows).
    %   t : current time; must be nonzero because 3/t is singular at 0
    %   y : column state [X; V] with V = X', length 2*n
    %   H : n-by-n matrix of the quadratic objective
    % Returns the column vector [X'; V'].
    half = numel(y)/2;
    pos = y(1:half);
    vel = y(half+1:end);
    dydt = [vel; -(3/t)*vel - (2*H)*pos];
end

function dydt = maximum_damping(t,y,H)
    % First-order form of  X'' + gamma*X' + 2*H*X = 0  with constant damping
    % gamma = 2*sqrt(2*max(eig(H))), i.e. 2*sqrt(beta) where beta is the
    % largest eigenvalue of 2*H (the matrix multiplying X in the force term).
    %   t : current time (unused; kept for a uniform (t,y,H) signature)
    %   y : column state [X; V] with V = X', length 2*n
    %   H : n-by-n matrix of the quadratic objective
    % Returns the column vector [X'; V'].
    m = numel(y)/2;
    gamma = 2*sqrt(2*max(eig(H)));
    xk = y(1:m);
    vk = y(m+1:end);
    dydt = [vk; -gamma*vk - (2*H)*xk];
end

function dydt = critical_damping(t,y,H)
    % First-order form of  X'' + gamma*X' + 2*H*X = 0  with constant damping
    % gamma = 2*sqrt(2*min(eig(H))), i.e. 2*sqrt(mu) where mu is the
    % smallest eigenvalue of 2*H (the matrix multiplying X in the force term).
    %   t : current time (unused; kept for a uniform (t,y,H) signature)
    %   y : column state [X; V] with V = X', length 2*n
    %   H : n-by-n matrix of the quadratic objective
    % Returns the column vector [X'; V'].
    k = numel(y)/2;
    gamma = 2*sqrt(2*min(eig(H)));
    xpos = y(1:k);
    xvel = y(k+1:end);
    dydt = [xvel; -gamma*xvel - (2*H)*xpos];
end

function dydt = gd(t,y,H)
    % Gradient flow  X' = -2*H*X  (the gradient of f(x) = x'*H*x), written on
    % the same [X; V] state layout as the damped-flow functions so it can
    % share their initial condition and time span.
    %   t : current time (unused; kept for a uniform (t,y,H) signature)
    %   y : column state [X; V], length 2*n; only X takes part in the dynamics
    %   H : n-by-n matrix of the quadratic objective
    % Returns the column vector [X'; V'].
    % The auxiliary half V is held constant (V' = 0). With V' = V it would
    % grow exponentially for any nonzero initial velocity and dominate
    % ode45's error control even though it never affects X.
    n = numel(y)/2;
    dydt = zeros(size(y));
    dydt(1:n) = -2*H*y(1:n); % \dot X
    % dydt(n+1:end) stays zero: \dot V
end
