% synth_visualize_problem:
%
% Some generic visualization plots obtained from the run of a synthetic
% problem.
%
% Input:
%
% - path_input, path_output: paths where to find problem definitions and
%                            results
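% - load_sorted_Z: if true, load the results of the adversarial (pre-sorted
%                  Z) setting, stored as 'adversarial_result_<problem>';
%                  otherwise load the results of the random setting,
%                  stored as 'result_<problem>'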
% - problem_number: which of the problems will be investigated
% - M: number of samples to generate prior information
% - v: which of the problem iterations to visualize

function synth_visualize_problem(path_input, path_output, load_sorted_Z, problem_number, M, v)
% Loads one synthetic problem and its stored results, draws M samples from
% the prior, and opens a series of figures showing prior features, the
% interventional/observational data, and prior/posterior curves for result
% iteration v.

%% Load

% Problems are selected by their position in the listing returned by dir,
% so problem_number follows that ordering (not necessarily numeric order).
names = dir(fullfile(path_input, 'problem*.mat'));
if isempty(names)
  error('synth_visualize_problem:noProblems', ...
        'No files matching problem*.mat were found in "%s".', path_input);
end
if ~isscalar(problem_number) || problem_number ~= floor(problem_number) || ...
   problem_number < 1 || problem_number > numel(names)
  error('synth_visualize_problem:badProblemNumber', ...
        'problem_number must be an integer between 1 and %d.', numel(names));
end
name = names(problem_number).name;

file_name = fullfile(path_input, name);
load(file_name, 'dat', 'model')
if load_sorted_Z
  fprintf('Using pre-sorted adversarial setting\n')
  file_name_output = fullfile(path_output, ['adversarial_result_', name]);
else
  file_name_output = fullfile(path_output, ['result_', name]);
end
% The result file supplies the remaining variables used below (X_space,
% f_space, do_params_X_space, hyper_a, X_exp_uniform, Y_exp_uniform,
% f_hat_confounded, model1). It is loaded wholesale, so anything it stores
% under an existing name (e.g. dat) replaces the value loaded above.
load(file_name_output)

% The last two columns of dat hold treatment X and outcome Y; the first p
% columns are the remaining variables.
p = size(dat, 2) - 2; %#ok<*NODEF>
X_idx = p + 1;
Y_idx = p + 2;
n = size(dat, 1);

% Standardize every column to zero mean and unit standard deviation.
m_data  = mean(dat);
sd_data = std(dat);
dat     = dat - repmat(m_data, n, 1);
dat     = dat ./ repmat(sd_data, n, 1);

default_xlim = [min(X_space) - 0.5 max(X_space) + 0.1];

% Shared plot styling. Defined after the load calls so that variables stored
% in the result file cannot overwrite these values.
font_sz = 20;
line_w  = 5;
gray    = [0.4,0.4,0.4];

%% Sample prior

% NOTE(review): prior_sample_affine_model is defined elsewhere; from its use
% below, a, b and f_obs are plotted against X_space (presumably one column
% per prior sample) and theta carries the sampled hyperparameters.
[a, b, f_obs, theta] = prior_sample_affine_model(X_space, do_params_X_space, hyper_a, M);

%% Visualize prior features

figure
imagesc(do_params_X_space.K)
title('{\itK_{obs}}', 'FontSize', font_sz)

figure
ksdensity(theta.a_sf2)
title('Prior on amplitude', 'FontSize', font_sz)

figure
ksdensity(theta.a_ell)
title('Prior on lengthscale', 'FontSize', font_sz)

%% Visualize data

figure
scatter(X_exp_uniform{v}, Y_exp_uniform{v}); hold on  %#ok<*USENS>
plot(X_space, f_space, '-g', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Outcome {\it Y}', 'FontSize', font_sz)
title(['Interventional data ({\itM} = ', num2str(length(Y_exp_uniform{v})), ')'], 'FontSize', font_sz)
yl = ylim;  % reused so the observational plot shares the same vertical range
xlim(default_xlim)

figure
scatter(dat(:, X_idx), dat(:, Y_idx)); hold on
plot(X_space, f_space, '-g', 'LineWidth', line_w)
plot(X_space, do_params_X_space.mu_do, '--r', 'LineWidth', line_w)
plot(X_space, f_hat_confounded{2}, '-om', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Outcome {\it Y}', 'FontSize', font_sz)
% Use the row count n: length(dat) would return the largest dimension, which
% is the number of columns whenever dat is wider than it is tall.
title(['Observational data ({\itN} = ', num2str(n), ')'], 'FontSize', font_sz)
ylim(yl)
xlim([min(dat(:, X_idx)) - 0.5 max(dat(:, X_idx)) + 0.1]);


%% Visualize prior and inference

figure
plot(X_space, f_obs, 'Color', gray); hold on
plot(X_space, f_space, '-g', 'LineWidth', line_w)
plot(X_space, do_params_X_space.mu_do, '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Outcome {\it Y}', 'FontSize', font_sz)
title('Prior: observational only', 'FontSize', font_sz)
yl = ylim;  % vertical range shared by the remaining outcome-scale figures
xlim(default_xlim)

figure
plot(X_space, a, 'Color', gray); hold on
plot(X_space, ones(length(X_space), 1), '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Distortion {\it H}', 'FontSize', font_sz)
title('Prior: distortion only', 'FontSize', font_sz)
xlim(default_xlim)

figure
plot(X_space, b, 'Color', gray); hold on
plot(X_space, zeros(length(X_space), 1), '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Distortion {\it H}', 'FontSize', font_sz)
title('Prior: translation only', 'FontSize', font_sz)
xlim(default_xlim)

figure
plot(X_space, a .* f_obs + b, 'Color', gray); hold on
plot(X_space, f_space, '-g', 'LineWidth', line_w)
plot(X_space, do_params_X_space.mu_do, '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Outcome {\it Y}', 'FontSize', font_sz)
title('Prior on dose-response', 'FontSize', font_sz)
ylim(yl)
xlim(default_xlim)

% Posterior quantities for the requested iteration; each dashed red line is
% the mean taken along the second dimension of the plotted matrix.
post      = model1{v};
dose_post = post.a .* post.f_obs + post.b;

figure; plot(X_space, post.f_obs, 'Color', gray); hold on
plot(X_space, f_space, '-g', 'LineWidth', line_w)
plot(X_space, mean(post.f_obs, 2), '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Outcome {\it Y}', 'FontSize', font_sz)
title('Posterior: observational only', 'FontSize', font_sz)
ylim(yl)
xlim(default_xlim)

% The y-axis label matches the corresponding prior plot: the curves shown
% are the distortion term, not the outcome.
figure; plot(X_space, post.a, 'Color', gray); hold on
plot(X_space, mean(post.a, 2), '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Distortion {\it H}', 'FontSize', font_sz)
title('Posterior: distortion only', 'FontSize', font_sz)
xlim(default_xlim)
refline(0, 1)  % horizontal reference at 1

figure; plot(X_space, post.b, 'Color', gray); hold on
plot(X_space, mean(post.b, 2), '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Distortion {\it H}', 'FontSize', font_sz)
title('Posterior: translation only', 'FontSize', font_sz)
xlim(default_xlim)
refline(0, 0)  % horizontal reference at 0

figure; plot(X_space, dose_post, 'Color', gray); hold on
plot(X_space, f_space, '-g', 'LineWidth', line_w)
plot(X_space, mean(dose_post, 2), '--r', 'LineWidth', line_w)
xlabel('Treatment {\it X}', 'FontSize', font_sz)
ylabel('Outcome {\it Y}', 'FontSize', font_sz)
title('Posterior on dose-response', 'FontSize', font_sz)
ylim(yl)
xlim(default_xlim)

%% Example

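% Example used in the paper (for the 'poly3_signal90' family of problems):
%   synth_visualize_problem(path_input, path_output, false, 5, false, 5000, 3)
% NOTE: this call lists seven arguments, but the current signature takes six
% (path_input, path_output, load_sorted_Z, problem_number, M, v). The second
% 'false' presumably belongs to an older signature, so the equivalent call is
% likely synth_visualize_problem(path_input, path_output, false, 5, 5000, 3).
% Notice that while the observational data can be recovered from the seed,
% the interventional data generated will be different.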

