# =================================================================================================#
# Description: Produces the experimental results for the stability analysis
# Author: Ryan Thompson
# =================================================================================================#

import CSV, CUDA, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

using Distributed

# Start two worker processes; runsim selects each worker's GPU from its process id
Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

include("Estimators/contextual_lasso.jl")

import CSV, CUDA, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

# gendata(par)
#
# Simulate training, validation, and test splits (n rows each) for one scenario.
# `par` is unpacked as (loss, n, p, m, mu_var, rho, s_min, s_max, id).
# - Explanatory features x ~ N(0, Sigma) with Sigma[i, j] = rho ^ |i - j|.
# - Contextual features z ~ Uniform(-1, 1) in each of m coordinates.
# - Coefficient j is nonzero only inside a Euclidean ball around a random center c[j, :]; the
#   radius is the s[j]-quantile of the training contexts' distances to that center, with
#   s[j] ~ Uniform(s_min, s_max). Inside the ball it falls linearly from 1 to 0.5.
# - The linear predictor is scaled so its uncorrected training variance equals mu_var.
# - Responses are Gaussian with unit variance when loss == Flux.mse, otherwise Bernoulli with a
#   logistic link (stored as floats).
# All randomness comes from the default RNG. Returns
# (x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test, beta_test).
function gendata(par)

    # Unpack the scenario; `id` only labels the replicate and is not used here
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Explanatory features, drawn for train, valid, and test in that order
    feature_dist = Distributions.MvNormal(zeros(p), [rho ^ abs(i - j) for i in 1:p, j in 1:p])
    x_train, x_valid, x_test = (permutedims(rand(feature_dist, n)) for _ in 1:3)

    # Contextual features, drawn for train, valid, and test in that order
    unit_box = Distributions.Uniform(- 1, 1)
    z_train, z_valid, z_test = (rand(unit_box, n, m) for _ in 1:3)

    # Sparsity levels, ball centers, and radii (radius j is the s[j]-quantile of the distances
    # from the training contexts to center j)
    s = rand(Distributions.Uniform(s_min, s_max), p)
    c = rand(unit_box, p, m)
    train_dist = [LinearAlgebra.norm(z_train[i, :] - c[j, :], 2) for i in 1:n, j in 1:p]
    r = [Statistics.quantile(train_dist[:, j], s[j]) for j in 1:p]

    # Coefficient value at context z: linear decay inside the ball, zeroed by the Bool factor
    # outside it
    coefficient(z, center, radius) = (1 - 0.5 * LinearAlgebra.norm(z - center, 2) / radius) *
        (LinearAlgebra.norm(z - center, 2) <= radius)
    coefficients(zs) = [coefficient(zs[i, :], c[j, :], r[j]) for i in 1:n, j in 1:p]
    beta_train = coefficients(z_train)
    beta_valid = coefficients(z_valid)
    beta_test = coefficients(z_test)

    # Row-wise linear predictor, rescaled using the training split's uncorrected std
    linear_predictor(x, b) = vec(sum(x .* b, dims = 2))
    mu_train = linear_predictor(x_train, beta_train)
    mu_valid = linear_predictor(x_valid, beta_valid)
    mu_test = linear_predictor(x_test, beta_test)
    kappa = sqrt(mu_var) / Statistics.std(mu_train, corrected = false)
    mu_train, mu_valid, mu_test = kappa .* mu_train, kappa .* mu_valid, kappa .* mu_test

    # Responses, drawn for train, valid, and test in that order
    function response(mu)
        loss == Flux.mse && return rand.(Distributions.Normal.(mu, 1))
        return float.(rand.(Distributions.Bernoulli.(1 ./ (1 .+ exp.(- mu)))))
    end
    y_train = response(mu_train)
    y_valid = response(mu_valid)
    y_test = response(mu_test)

    return (x_train, z_train, y_train, x_valid, z_valid, y_valid,
        x_test, z_test, y_test, beta_test)

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

# evaluate!(result, estimator, beta_hat_1, beta_hat_2, par)
#
# Append one row to `result` measuring how much the sparsity patterns of two coefficient
# estimates disagree. `hamming_dist` is the number of entries that are nonzero in exactly one
# of the two estimates, divided by size(beta_hat_1, 1). The scenario `par` is unpacked as
# (loss, n, p, m, mu_var, rho, s_min, s_max, id) and stored after `estimator` and the metric.
# Returns whatever push! returns, i.e. `result`.
function evaluate!(result, estimator, beta_hat_1, beta_hat_2, par)

    # Unpack the scenario so it can be stored next to the metric
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Supports (nonzero masks) of both estimates and their per-row average disagreement
    support_1 = beta_hat_1 .!= 0
    support_2 = beta_hat_2 .!= 0
    hamming_dist = sum(support_1 .!= support_2) / size(beta_hat_1, 1)

    # Columns follow the layout of the result table built in runsim
    return push!(result, [estimator, hamming_dist, loss, n, p, m, mu_var, rho, s_min, s_max, id])

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# runsim(par)
#
# Run one replicate of the stability experiment for scenario `par` (unpacked as
# (loss, n, p, m, mu_var, rho, s_min, s_max, id)) and return a one-row DataFrame.
# Data come from gendata(par). The contextual lasso is fitted twice to the same
# training/validation data, both fits are evaluated at the test contexts, and evaluate!
# records how much their sparsity patterns disagree.
function runsim(par)

    # Unpack scenario parameters
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Pin this worker to a GPU, cycling over the visible devices by process id
    # (hard-coding `% 2` selected a nonexistent device on single-GPU machines)
    CUDA.device!((Distributed.myid() - 1) % CUDA.ndevices())

    # Reseed the default RNG from the scenario itself. pmap hands rows to whichever worker is
    # free, so a per-worker seed alone makes each row's random stream depend on scheduling.
    Random.seed!(Random.default_rng(), hash((n, p, m, mu_var, rho, s_min, s_max, id)))

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], hamming_dist = [], loss = [], n = [], p = [], m = [], mu_var = [], rho = [], 
        s_min = [], s_max = [], id = []
    )

    # Generate data; test responses and true test coefficients are not needed here
    x_train, z_train, y_train, x_valid, z_valid, y_valid, _, z_test, _, _ = gendata(par)

    # Three equal-width hidden layers sized so the network has roughly 32 * m * p parameters.
    # With width h, m inputs, and p outputs (fully connected layers with biases -- assumed
    # architecture, TODO confirm against ContextualLasso), the count is
    # 2h^2 + (m + p + 3)h + p; n_neuron is the positive root of that quadratic = 32mp.
    n_neuron = round(Int, 1 / 4 * (sqrt((m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)) - m - p - 3))
    hidden_layers = fill(n_neuron, 3)

    # Fit the contextual lasso on identical data and return its coefficients at the test
    # contexts. Repeated calls can differ only through randomness inside the fit
    # (presumably initialization).
    fit_coef() = ContextualLasso.coef(ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        verbose = false, intercept = false, relax = true,
        loss = loss, hidden_layers = hidden_layers
    ), z_test)
    beta_hat_1 = fit_coef()
    beta_hat_2 = fit_coef()

    # Record how much the two fitted sparsity patterns disagree
    evaluate!(result, "Contextual lasso", beta_hat_1, beta_hat_2, par)

    # Release cached GPU memory before the next simulation on this worker
    CUDA.reclaim()

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters. Each design is a full factorial grid over the values listed;
# the two designs share every setting except p (explanatory features) and m (contextual
# features).
simulations = let
    # Number of samples: 10 values evenly spaced on the log scale from 100 to 100,000, rounded
    n_grid = round.(Int, exp.(range(log(100), log(100000), 10)))
    design(p_grid, m_grid) = DataFrames.DataFrame(
        (loss = loss, n = n, p = p, m = m, mu_var = mu_var, rho = rho, s_min = s_min,
            s_max = s_max, id = id) for
        loss = [Flux.mse],
        n = n_grid, # Number of samples
        p = p_grid, # Number of explanatory features
        m = m_grid, # Number of contextual features
        mu_var = 5, # Signal-to-noise ratio (variance of the mean; Gaussian noise has variance 1)
        rho = 0.5, # Correlation coefficient
        s_min = 0.05, # Minimal sparsity level
        s_max = 0.15, # Maximal sparsity level
        id = 1:10 # Simulation run ID
    )
    vcat(design(10, 2), design(50, [2, 5]))
end

# Run all simulations
# CUDA.jl is not reproducible with the default RNG. So on every process (master and workers)
# the zero-argument Random.default_rng() is redefined to return a per-process MersenneTwister
# seeded from the process id. Draws that do not pass an explicit RNG presumably use it.
# NOTE(review): this redefines a method owned by the Random stdlib from Main. Confirm it still
# takes effect on the Julia version in use.
# NOTE(review): the seed depends only on the process id, while pmap hands rows to whichever
# worker is free. Unless each simulation reseeds itself, the random stream a given row sees
# can therefore vary between runs.
Distributed.@sync Distributed.@everywhere begin
    rng = Random.MersenneTwister((Distributed.myid() - 1) % 2)
    Random.default_rng() = rng
end
# Run one simulation per design row across the workers, with a progress bar
result = ProgressMeter.@showprogress pmap(runsim, eachrow(simulations))
# Stack the per-simulation result tables into one table and save it
result = reduce(vcat, result)
CSV.write("Results/stability.csv", result)

# Shut down the worker processes
Distributed.rmprocs(Distributed.workers())