% hmc_sample_standard_deep_gp: per-stratum HMC (Stan) sampling for a deep GP model.

function [theta_mcmc, f_mcmc, f_obs_mcmc] = ...
         hmc_sample_standard_deep_gp(X_exp, Y_exp, X_space, do_params_X_space, ...
                                     prior_info, hyper_a, X_obs, Z_obs, burn_in, M, verbose)
% HMC_SAMPLE_STANDARD_DEEP_GP  Per-stratum Stan/HMC sampling of a two-layer GP
% model, followed by conjugate Gaussian draws of the latent f on X_space.
%
% For each stratum s, the Stan program 'sample/deep.stan' (path is relative to
% the current working directory) is run with one chain for M iterations, of
% which burn_in are warmup. For every retained draw of (theta, f_obs), f is then
% drawn from its Gaussian full conditional under
%     f ~ N(0, K),   y ~ N(A * f, likv * I),
%     K_ij = sf2 * exp(-0.5 * (f_obs_i - f_obs_j)^2 / ell) + noise_matrix_ij,
% where A = multi_exp_mapping(...) for that stratum.
%
% Inputs (cell arrays are indexed by stratum s = 1..length(X_exp)):
%   X_exp, X_obs, Z_obs, prior_info - cell arrays forwarded to multi_exp_mapping
%   Y_exp             - cell array; Y_exp{s} is the response vector of stratum s
%   X_space           - vector of evaluation points shared by all strata
%   do_params_X_space - cell array of structs with fields mu_do, K_scale and
%                       var_err_do (mu_do also initialises f_obs in Stan)
%   hyper_a           - struct with hyper_a.sf2.mu/.var and hyper_a.ell.mu/.var,
%                       passed to Stan as entries 1-2 of the data vectors m/v
%                       (presumably prior mean/variance of theta; confirm in deep.stan)
%   burn_in           - number of Stan warmup iterations
%   M                 - total number of Stan iterations (warmup included)
%   verbose           - optional, default false; prints progress and is
%                       forwarded to stan() as 'verbose'
%
% Outputs (S x 1 cell arrays):
%   theta_mcmc{s} - draws x 3; columns are log(sf2), log(ell), log(likv)
%   f_mcmc{s}     - draws x length(X_space) matrix of sampled f
%   f_obs_mcmc{s} - Stan draws of f_obs, one row per draw
%
% Errors if the design matrix of a stratum does not have one column per
% element of X_space.

%% Preliminaries

if nargin < 11, verbose = false; end

S = length(X_exp);
theta_mcmc = cell(S, 1);
f_obs_mcmc = cell(S, 1);
A_all = cell(S, 1);   % design matrices, reused when drawing f below
num_X_space = length(X_space);

% Stan data m/v for theta = [log sf2; log ell; log likv]. Entry 3 of theta_m is
% overwritten per stratum with log(var_err_do).
theta_m = [hyper_a.sf2.mu; hyper_a.ell.mu; 0];
theta_v = [hyper_a.sf2.var; hyper_a.ell.var; 1];
noise_matrix = get_noise_matrix(num_X_space);

%% Run sampler

for s = 1:S

  if verbose, fprintf('HMC Deep GP, stratum %d of %d\n', s, S); end

  y_s = Y_exp{s}(:);
  n_exp = length(y_s);
  A = multi_exp_mapping(X_space, X_exp{s}, X_obs{s}, Z_obs{s}, prior_info{s});
  if size(A, 2) ~= num_X_space
    error('Currently requires at least one observation per dose level')
  end
  A_all{s} = A;
  theta_m(3) = log(do_params_X_space{s}.var_err_do);

  % 'y' must be this stratum's numeric vector: struct() expands a cell-array
  % value into a struct array, and N must equal length(y).
  stan_dat = struct('N', n_exp, 'N_s', size(A, 2), 'm', theta_m, 'v', theta_v, 'A', A, ...
                    'noise_matrix', noise_matrix, ...
                    'mu_do', do_params_X_space{s}.mu_do, ...
                    'Sigma_do', do_params_X_space{s}.K_scale, 'y', y_s);
  stan_init = struct('theta', theta_m, 'f_obs', do_params_X_space{s}.mu_do);
  fit = stan('file', 'sample/deep.stan', 'data', stan_dat, 'iter', M, 'chains', 1, ...
             'init', stan_init, 'warmup', burn_in, 'refresh', 1, 'verbose', verbose);
  % Wait for the sampler to finish without spinning a CPU core.
  while is_running(fit)
    pause(0.1);
  end

  samples = fit.extract('permuted', true);
  theta_mcmc{s} = samples.theta;
  f_obs_mcmc{s} = samples.f_obs;

end

%% Generate f samples

f_mcmc = cell(S, 1);
I_space = eye(num_X_space);

for s = 1:S

  A = A_all{s};
  y_s = Y_exp{s}(:);
  n_draws = size(theta_mcmc{s}, 1);
  f_mcmc{s} = zeros(n_draws, num_X_space);
  % Loop-invariant pieces of the Gaussian full conditional of f.
  AtA = A' * A;
  Aty = A' * y_s;

  for k = 1:n_draws
    sf2  = exp(theta_mcmc{s}(k, 1));
    ell  = exp(theta_mcmc{s}(k, 2));
    likv = exp(theta_mcmc{s}(k, 3));
    % The sampled f_obs are the inputs of the outer GP kernel (the prior mean
    % of f is zero). NOTE(review): ell divides the squared distance directly
    % (not ell^2); confirm deep.stan uses the same parameterisation.
    f_obs_k = f_obs_mcmc{s}(k, :)';
    R_k = f_obs_k(:, ones(1, num_X_space));
    prior_cov = sf2 * exp(-0.5 * (R_k - R_k').^2 / ell) + noise_matrix;
    prior_prec = prior_cov \ I_space;
    prior_prec = (prior_prec + prior_prec') / 2;   % remove round-off asymmetry
    post_prec = prior_prec + AtA / likv;
    R = chol(post_prec);                           % post_prec = R' * R
    post_mean = R \ (R' \ (Aty / likv));
    % R \ z has covariance inv(R) * inv(R)' = inv(post_prec).
    f_mcmc{s}(k, :) = (post_mean + R \ randn(num_X_space, 1))';
  end

end
