using ArgParse, JLD2, Printf, JSON, Dates, IterTools, Random;
using Distributed;

@everywhere include("runiteps.jl");
@everywhere include("helpers.jl");
@everywhere include("../binary_search.jl");
include("helpers_experiments.jl");

"""
    parse_commandline()

Declare the command-line options of the ϵ-best-arm identification experiment and
parse `ARGS` against them.

Returns the dictionary produced by `ArgParse.parse_args`, keyed by option name
without the leading `--` (e.g. `"save_dir"`, `"K"`, `"eps"`).
"""
function parse_commandline()
    s = ArgParseSettings();

    @add_arg_table! s begin
        "--save_dir"
            help = "Directory for saving the experiment's data."
            arg_type = String
            default = "experiments/"
        "--data_dir"
            help = "Directory for loading the data."
            arg_type = String
            default = "data/"
        "--seed"
            help = "Seed."
            arg_type = Int64
            default = 42
        "--inst"
            help = "Instance considered."
            arg_type = String
            default = "2G"
        "--K"
            help = "Number of arms."
            arg_type = Int64
            default = 6
        "--expe"
            help = "Experiment considered."
            arg_type = String
            default = "epsTest"
        "--Nruns"
            help = "Number of runs of the experiment."
            arg_type = Int64
            default = 16
        "--wdeltas"
            help = "Run experiments with extended values for delta."
            action = :store_true
        "--history"
            help = "History of recommendations to keep: 'none' or 'partial'."
            arg_type = String
            default = "partial"
        "--freqHist"
            help = "Frequency of storing recommendations."
            arg_type = Int64
            default = 20
        "--eps"
            help = "Approximation parameter."
            arg_type = Float64
            default = 0.1
        "--Ie"
            help = "Number of good arms for the 2G instance."
            arg_type = Int64
            default = 2
        "--opt"
            help = "ϵ-optimality considered: `mul` or `add`."
            arg_type = String
            default = "add"
        "--Tmax"
            help = "Maximum time."
            arg_type = Int64
            default = 1000000
    end

    return parse_args(s);
end

# Parameters: read the parsed command line once and unpack every option into the
# script-level variables used below (note: `--K` is stored as `nK`, `--eps` as `ϵ`).
parsed_args = parse_commandline();
save_dir, data_dir, seed, inst, nK, expe, Nruns, wdeltas, history, freqHist, ϵ, Ie, opt, Tmax =
    map(key -> parsed_args[key],
        ("save_dir", "data_dir", "seed", "inst", "K", "expe", "Nruns",
         "wdeltas", "history", "freqHist", "eps", "Ie", "opt", "Tmax"));

# Storing parameters defining the instance (passed to `get_instance_experiment`
# when the data has to be generated).
param_inst = Dict("inst" => inst, "nK" => nK, "Ie" => Ie, "data_dir" => data_dir);

# Confidence levels δ considered by the runs.
# NOTE(review): presumably one β threshold per δ inside `runit` — confirm.
δs = wdeltas ? [0.1, 0.01, 0.001] : [0.01];

# Naming files and folder.
# NOTE(review): `_ϵ` keeps the part of `string(ϵ)` after the first "." (0.1 -> "1",
# 0.05 -> "05"); a value Julia prints in scientific notation (1.0e-5) gives "0e-5".
_ϵ = split(string(ϵ), ".")[2];
name_data = opt * "_e" * _ϵ * "_" * inst * (inst == "2G" ? string(Ie) : "") * "_K" * string(nK);
data_file = data_dir * name_data * ".dat";
now_str = Dates.format(now(), "dd-mm_HHhMM");
experiment_name = "exp_" * name_data * "_" * expe * "_" * history * (history == "partial" ? string(freqHist) : "") * (wdeltas ? "_delta" : "") * "_N" * string(Nruns);
# NOTE(review): ":" is not a valid character in Windows directory names.
experiment_dir = save_dir * now_str * ":" * experiment_name * "/";

# `mkdir` throws if the parent directory (e.g. the default "experiments/") does not
# exist yet, so create it first. The experiment directory itself is still created
# with `mkdir`: relaunching the same settings within the same minute errors instead
# of silently overwriting earlier results.
experiment_parent = dirname(dirname(experiment_dir));
isempty(experiment_parent) || mkpath(experiment_parent);
mkdir(experiment_dir);

# Keep the exact command-line arguments next to the results.
open("$(experiment_dir)parsed_args.json","w") do f
    JSON.print(f, parsed_args)
end

# For reproducibility, load the data if already defined; otherwise generate it,
# solve the oracle problem and cache everything in `data_file`.
if isfile(data_file)
    @load data_file dists μs Tstar wstar param_inst;
else
    @warn "Generating new data.";

    # Seed the global RNG so that a regenerated instance only depends on `--seed`.
    # NOTE(review): assumes `get_instance_experiment` draws from the global RNG;
    # harmless if the instance is deterministic.
    Random.seed!(seed);

    # Parameters: means μs and the corresponding arm distributions.
    μs, dists = get_instance_experiment(param_inst);

    # Oracle: Tstar and wstar (presumably characteristic time and optimal allocation).
    pep = EspilonBestArm(dists, ϵ, opt);
    Tstar, wstar = oracle(pep, μs);

    # `@save` does not create missing directories: make sure `data_dir` exists.
    data_parent = dirname(data_file);
    isempty(data_parent) || mkpath(data_parent);
    @save data_file dists μs Tstar wstar param_inst;
end
# Copy the instance into the experiment directory, together with its source path.
@save "$(experiment_dir)$(name_data).dat" data_file dists μs Tstar wstar param_inst;

# Get Tau_max: for "eq2nd" experiments the budget is a multiple of `lbd`, computed at
# the smallest δ from Tstar; every other experiment uses `--Tmax`.
min_δ = minimum(δs);
lbd = (1 - 2 * min_δ) * log((1 - min_δ) / min_δ) * Tstar;
timeout_factor = if occursin("GK16", expe)
    15
else
    60
end;
if occursin("eq2nd", expe)
    Tau_max = timeout_factor * lbd;
else
    Tau_max = Tmax;
end

# Pure exploration problem.
pep = EspilonBestArm(dists, ϵ, opt);

# Identification strategies used on this instance: tuples (sr, rsp).
iss = everybody(expe, wstar);

# Run the experiments in parallel: one job per (strategy, run index) pair, where
# run i is seeded with `seed + i` whatever the strategy.
@time data = pmap(Iterators.product(iss, 1:Nruns)) do (is, i)
    runit(seed + i, is, pep, μs, δs, Tau_max, history, freqHist)
end;

# Persist the instance, the strategies and every run's output with JLD2.
results_file = string(experiment_dir, experiment_name, ".dat");
@save results_file dists μs Tstar wstar pep iss data δs Nruns seed;

# Write a text summary of the problem considered and of the collected runs.
file = string(experiment_dir, "summary_", experiment_name, ".txt");
print_eps_summary(pep, dists, μs, Tstar, wstar, δs, iss, data, Nruns, file);
