%% set up
% Reset the workspace, put the helper folders on the search path, choose a
% simulation design and draw one data set from it.

clear;

% Add the folders one at a time, in the original order, so the resulting
% search-path precedence is the same as three consecutive addpath calls.
for helper_dir={'../aux','../figures','../results'}
    addpath(helper_dir{1});
end

% Design label; the alternatives noted by the author are NP, SSG and HLLT.
design='NP';
[f,sim,x_vis,f_vis]=get_design(design);

% sample size
N=1000;

% Reset the random number generator to MATLAB's default state so the
% simulated draw is reproducible.
rng default;
[x,y,z]=sim(f,N);

% Split into stage 1 and stage 2 samples (0.5 is the split argument handed
% to the helper). The six outputs are not referenced again in this script.
[x1, x2, y1, y2, z1, z2]=split(x,y,z,0.5);

%% sieve
% Tune the two hyperparameters (lambda, then xi) one stage at a time and
% compute predictions with the tuned values and with both set to zero.

df=get_basis(x,y,z,x_vis);

% Starting values for the searches, kept on the log scale
% (log(0.05) is the value the author notes for the NP design).
log_start=log(0.05);
lambda_0=log_start; % hyp1 = lambda
xi_0=log_start;     % hyp2 = xi

% Stage 1: search over log(lambda); exp() maps the unconstrained search
% variable back to a positive hyperparameter.
sieve1_obj=@(log_lambda) sieve1_loss(df,exp(log_lambda));
lambda_star=fminunc(sieve1_obj,lambda_0);

% Stage 2: hold lambda at its tuned value and search over log(xi).
lambda_hat=exp(lambda_star);
sieve2_obj=@(log_xi) sieve2_loss(df,[lambda_hat,exp(log_xi)]);
xi_star=fminunc(sieve2_obj,xi_0);

% Evaluate on the full sample: once with the tuned hyperparameters and once
% with both hyperparameters set to zero (no ridge). The trailing 3 is passed
% to sieveIV_pred unchanged; its meaning is defined by that helper.
hyp_tuned=[lambda_hat,exp(xi_star)];
hyp_zero=[0,0];
y_vis=sieveIV_pred(df,hyp_tuned,3);
y_vis_noridge=sieveIV_pred(df,hyp_zero,3);

% visualize estimator
% Two figures are written to ../figures: one comparing the unpenalized and
% ridge fits, and one showing the unpenalized fit alone.
line_w=5;
font_sz=20;
grey=[.5 .5 .5];
fig_dir='../figures';

% Figure 1: structural function, data, SieveIV (green) and SieveIV with
% ridge (red). The legend entries follow the plotting order.
plot(x_vis,f_vis,'LineWidth',line_w);
hold on
scatter(x,y,36,grey);
plot(x_vis,y_vis_noridge,'--g','LineWidth',line_w);
plot(x_vis,y_vis,'--r','LineWidth',line_w);
xlabel('x','FontSize',font_sz);
ylabel('y','FontSize',font_sz);
legend({'Structural function','Data','SieveIV','SieveIV(ridge)'},'Location','southeast','FontSize',font_sz);
hold off
saveas(gcf,fullfile(fig_dir,strcat('sieveIVridge_',design)),'epsc');

% Figure 2: same layout without the ridge fit. Hold is off at this point,
% so the first plot call replaces the contents of the current axes.
plot(x_vis,f_vis,'LineWidth',line_w);
hold on
scatter(x,y,36,grey);
plot(x_vis,y_vis_noridge,'--r','LineWidth',line_w);
xlabel('x','FontSize',font_sz);
ylabel('y','FontSize',font_sz);
legend({'Structural function','Data','SieveIV'},'Location','southeast','FontSize',font_sz);
hold off
saveas(gcf,fullfile(fig_dir,strcat('sieveIV_',design)),'epsc');

%% sieve_ridge IV - repeated simulations
% Repeat the simulate-and-predict step n_trials times, record each trial's
% prediction and its mse against f_vis, then plot the pointwise mean and the
% lower/upper (5%/95%) order-statistic bands together with f_vis.
% NOTE: this section does not call addpath itself; it relies on the helper
% folders already being on the MATLAB search path.

clear;
rng('default')
design='NP'; %NP, SSG, HLLT
[f,sim,x_vis,f_vis]=get_design(design);

% sample size and number of simulated data sets
N=1000;
n_trials=100;

% One row per trial, one column per point of the visualization grid. Each
% row is compared with f_vis and plotted against x_vis, so the width must
% follow the grid length rather than the sample size N.
n_vis=numel(x_vis);
results=zeros(n_trials,n_vis);
results_mse=zeros(n_trials,1);

for i=1:n_trials
    % progress report every 10 trials
    if mod(i,10)==0
        disp(num2str(i));
    end
    results(i,:)= sim_pred(design,N)';
    results_mse(i)=mse(results(i,:)',f_vis);
end

% Pointwise summaries across trials. The dimension is given explicitly:
% with a single trial the default (first non-singleton dimension) would
% operate along the grid instead of across trials.
means=mean(results,1);
sorted=sort(results,1);

% Rows of the sorted matrix used for the bands, scaled with n_trials and
% clipped to valid row indices (5 and 95 when n_trials is 100).
lo_idx=max(1,round(0.05*n_trials));
hi_idx=min(n_trials,max(1,round(0.95*n_trials)));
fifth=sorted(lo_idx,:);
nintetyfifth=sorted(hi_idx,:);

% visualize the simulated trials
plot(x_vis,f_vis,'LineWidth',2);
hold on;
plot(x_vis,means,'--r','LineWidth',2);
plot(x_vis,fifth,':r','LineWidth',2);
plot(x_vis,nintetyfifth,':r','LineWidth',2);
xlabel('x');
ylabel('y');
hold off;
saveas(gcf,fullfile('../figures',strcat('sieveIVridge_',design,'_',num2str(n_trials))),'epsc');