"""
Usage:
    python3 -m dp_join.experiment.adult_logistic_regression [out_dir]

Experimentally evaluates logistic regression on private joins with a range of
parameters, using the Adult dataset.

Writes output to the directory at out_dir as a set of plots in both png and
pdf formats, with an index.html file gathering the results. Also saves csv files
containing the values that were plotted. If out_dir isn't specified, a new
directory out/adult_logistic_regression_<time> is created, where <time> comes
from fmttime().

If out_dir already exists and has CSV files in it, this script will redraw the
plots without rerunning the experiments that have already been done. The
intended use of this is to be able to modify the plotting code and quickly
regenerate the plots without rerunning time-consuming experiments. This only
works if experiment code hasn't changed too much since the CSV files were
generated: specifically, we assume the new and old code use the same name for
the CSV file corresponding to each plot, and that the meaning of each CSV column
is the same as before. Also, the repetitions_per_experiment variable should not
have changed, otherwise the y axis label will be inaccurate.
"""

from collections import namedtuple
from concurrent.futures import ProcessPoolExecutor
import math
from matplotlib import pyplot
import numpy
import numpy.random
import os
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import sys

from .. import adult
from ..hash import salted_sha256
from ..sketch.one_hot import OneHotSketcher
from ..timing import fmttime

# Constants
# Number of independent runs (each with its own random seed) behind every
# plotted point. Also quoted in the y axis label of every plot.
repetitions_per_experiment = 17
# Resolution used when saving the png version of each plot.
plot_dpi = 72
# Privacy parameters swept by the "vary epsilon" plots.
epsilons = (0.01, 0.1, 0.5, 1.0, 2.0)
# Opening and closing boilerplate of the index.html report. The report body
# writes its own <p> and <h2> elements, so the header must not leave a <p>
# open: HTML does not allow those elements inside a paragraph.
report_header = """\
<!DOCTYPE html>
<html>
<body>
"""
report_footer = "</body>\n</html>\n"

# Common inputs used by the experiment code, bundled so they can be passed
# around as a single argument.
Resources = namedtuple(
    typename = "Resources",
    field_names = [
        # A numpy SeedSequence. It is spawned to give every worker process its
        # own independent random number generator.
        "seed_sequence",
        # Test accuracy of a model trained on the full training set without
        # privacy. Used as a reference line on the plots.
        "nonprivate_test_accuracy",
        # Writable text file that receives the HTML report.
        "html_out",
        # Sequence of (features, label) pairs used for training.
        "training_set",
        # Sequence of (features, label) pairs used to measure test accuracy.
        "test_set",
        # Directory where CSV files and plots are written.
        "out_dir",
    ],
)

# Parameters describing one point on an experiment plot: how to simulate the
# join, and where to draw the result.
ExperimentParameters = namedtuple(
    typename = "ExperimentParameters",
    field_names = [
        # Differential privacy parameter. None means no privacy: the model is
        # trained on the true labels and no sketch is built.
        "epsilon",
        # Number of buckets in the label sketch (passed to the sketcher as
        # num_buckets).
        "sketch_dimension",
        # Position of this result along the x axis.
        "x_axis_value",
        # Text shown for this result along the x axis, e.g. the string "500K"
        # when x_axis_value is 500000.
        "x_axis_label",
        # Whether the labels are shared as a difference sketch. Default: True.
        "use_difference_sketch",
        # Number of people on both sides of the join. Default None: use all
        # rows of the training set.
        "num_rows_both_sides",
        # Number of people only on the features side of the join. Default 0.
        "num_rows_just_features",
        # Number of people only on the labels side of the join. Default 0.
        "num_rows_just_labels",
    ],
    # Defaults for the last four fields, in order.
    defaults = [True, None, 0, 0],
)

def int_to_short_string(n):
    """
    Converts the integer n to a string. Uses a suffix "M" (millions) or "K"
    (thousands) when that loses no precision, e.g. 500_000 -> "500K",
    5_000_000 -> "5M", 32561 -> "32561".

    Zero is returned as "0" rather than the misleading "0M".
    """
    if n == 0:
        return "0"
    for divisor, suffix in ((1_000_000, "M"), (1000, "K")):
        if n % divisor == 0:
            return f"{n // divisor}{suffix}"
    return str(n)

def new_pipeline(**logistic_regression_kwargs):
    """
    Returns an unfitted sklearn Pipeline with two steps:
    - standardize the features;
    - fit a logistic regression with no penalty.

    Extra keyword arguments are forwarded to the LogisticRegression
    constructor.
    """
    # NOTE(review): scikit-learn 1.2 deprecated penalty="none" in favor of
    # penalty=None, and 1.4 removed it. Switch once the project's minimum
    # scikit-learn version allows.
    scaler = StandardScaler()
    classifier = LogisticRegression(penalty = "none",
                                    **logistic_regression_kwargs)
    return Pipeline(steps = (
        ("transform", scaler),
        ("logistic_regression", classifier),
    ))

def train_model_from_difference_sketch(identities, features, label_sketch,
                                       sketcher):
    """
    Trains a model on labels reconstructed from the difference sketch
    label_sketch, and returns the fitted sklearn pipeline.

    The sketcher produces one estimate per identity, and each estimate
    becomes a weighted training example:
    - a positive estimate gives label 0;
    - a negative estimate gives label 1;
    - the example's weight is the estimate's magnitude.

    Rows whose estimate is exactly zero carry no information and are dropped.
    """
    estimates = sketcher.estimate_training_labels_from_difference_sketch(
        sketch = label_sketch,
        identities = identities)
    kept_features, kept_labels, kept_weights = [], [], []
    for row_index, row in enumerate(features):
        estimate = estimates[row_index]
        if estimate != 0:
            kept_features.append(row)
            kept_labels.append(0 if estimate > 0 else 1)
            kept_weights.append(abs(estimate))
    pipeline = new_pipeline()
    pipeline.fit(kept_features, kept_labels,
                 logistic_regression__sample_weight = kept_weights)
    return pipeline

def train_model_from_non_difference_sketch(identities, features, label_sketch,
                                           sketcher):
    """
    Like train_model_from_difference_sketch, but label_sketch is an ordinary
    (non-difference) sketch over two label categories.

    The sketcher estimates a weight for every (row, label) pair. Each pair
    with a nonzero weight becomes one weighted training example, so a feature
    row can be used once per label. A large max_iter is passed to
    LogisticRegression, but training fails to converge anyway. Returns the
    fitted sklearn pipeline.
    """
    weight_matrix = sketcher.estimate_training_weights(
        sketch = label_sketch,
        identities = identities,
        num_categories = 2)
    example_features = []
    example_labels = []
    example_weights = []
    for row_index, row in enumerate(features):
        row_weights = (weight_matrix[row_index, 0], weight_matrix[row_index, 1])
        for label, weight in enumerate(row_weights):
            if weight != 0:
                example_features.append(row)
                example_labels.append(label)
                example_weights.append(weight)
    pipeline = new_pipeline(max_iter = 100_000)
    pipeline.fit(example_features, example_labels,
                 logistic_regression__sample_weight = example_weights)
    return pipeline

def pick_training_examples(training_set, test_set, rng, parameters):
    """
    Splits the training set between the two simulated parties according to
    `parameters` (an ExperimentParameters). test_set is unused.

    Returns ((feature_ids, features), (label_ids, labels)). feature_ids are
    the made-up identities of the rows in `features`, and likewise for
    labels.

    If parameters.num_rows_both_sides is None, both sides get every row in
    the original order. Otherwise the rows are shuffled with rng, and the two
    sides can differ (even in length). The first
    parameters.num_rows_both_sides identities are the same on both sides.
    After those, the features side gets num_rows_just_features extra rows
    and the labels side gets num_rows_just_labels different extra rows.
    """
    # Made-up identities to join on; identity i belongs to training_set[i].
    identities = tuple(f"Identity #{i}" for i in range(len(training_set)))

    if parameters.num_rows_both_sides is None:
        features, labels = zip(*training_set)
        return (identities, features), (identities, labels)

    order = rng.permutation(len(training_set))
    shuffled_identities = tuple(identities[i] for i in order)
    shuffled_features, shuffled_labels = zip(*(training_set[i] for i in order))

    # Layout of the shuffled rows:
    #   [0, both)                          -> on both sides
    #   [both, features_end)               -> features side only
    #   [features_end, + just_labels)      -> labels side only
    both = parameters.num_rows_both_sides
    features_end = both + parameters.num_rows_just_features
    labels_only = slice(features_end,
                        features_end + parameters.num_rows_just_labels)

    feature_side = (shuffled_identities[:features_end],
                    shuffled_features[:features_end])
    label_side = (shuffled_identities[:both] + shuffled_identities[labels_only],
                  shuffled_labels[:both] + shuffled_labels[labels_only])
    return feature_side, label_side

def do_nonprivate_experiment_once(training_set, test_set, num_rows_both_sides,
                                  rng = None):
    """
    Trains a model directly on the true labels (no sketch, no privacy) and
    returns its accuracy on test_set. do_experiment_once uses this when
    parameters.epsilon is None.

    If num_rows_both_sides is None, the whole training set is used and rng is
    not needed. Otherwise training uses a random subset of that many rows,
    chosen with rng.
    """
    if num_rows_both_sides is not None:
        chosen = rng.permutation(len(training_set))[:num_rows_both_sides]
        training_set = [training_set[i] for i in chosen]
    features, labels = zip(*training_set)
    pipeline = new_pipeline()
    pipeline.fit(features, labels)
    test_features, test_labels = zip(*test_set)
    return pipeline.score(test_features, test_labels)

def do_experiment_once(training_set, test_set, parameters, random_seed):
    """
    Runs a single experiment. When parameters.epsilon is not None, it
    simulates a two-party setting:
    - one party makes a joinable private sketch of the training labels, using
      privacy parameter epsilon;
    - the other party trains a model from that sketch and its own training
      features.

    random_seed seeds the numpy random generator used for all randomness in
    the run. Returns the trained model's accuracy on test_set, evaluated with
    the real test features and labels (the sketch is not used for testing).
    """
    rng = numpy.random.default_rng(random_seed)

    if parameters.epsilon is None:
        # No privacy: skip sketching and train on the true labels.
        return do_nonprivate_experiment_once(
            training_set, test_set, parameters.num_rows_both_sides, rng)

    # Appending a random salt to every hash input makes each run use a
    # different hash function.
    salt = str(rng.integers(2**31))
    sketcher = OneHotSketcher(hash_function = salted_sha256(salt), rng = rng)

    feature_side, label_side = pick_training_examples(
        training_set = training_set,
        test_set = test_set,
        rng = rng,
        parameters = parameters,
    )
    feature_identities, features = feature_side
    label_identities, labels = label_side

    use_difference_sketch = parameters.use_difference_sketch
    label_sketch = sketcher.sketch_values(
        epsilon = parameters.epsilon,
        identities = label_identities,
        num_categories = 2,
        values = labels,
        num_buckets = parameters.sketch_dimension,
        difference_sketch = use_difference_sketch,
    )

    if use_difference_sketch:
        train = train_model_from_difference_sketch
    else:
        train = train_model_from_non_difference_sketch
    pipeline = train(
        identities = feature_identities,
        features = features,
        label_sketch = label_sketch,
        sketcher = sketcher,
    )

    test_features, test_labels = zip(*test_set)
    return pipeline.score(test_features, test_labels)

def do_experiment_repeated(resources, parameters):
    """
    Runs do_experiment_once repetitions_per_experiment times in parallel
    worker processes. Each run gets its own seed spawned from
    resources.seed_sequence. Returns a tuple of the measured test
    accuracies, in the order the seeds were spawned.
    """
    seeds = resources.seed_sequence.spawn(repetitions_per_experiment)
    count = len(seeds)
    with ProcessPoolExecutor(max_workers = os.cpu_count()) as executor:
        accuracies = executor.map(
            do_experiment_once,
            [resources.training_set] * count,
            [resources.test_set] * count,
            [parameters] * count,
            seeds,
        )
        return tuple(accuracies)

def do_experiment_series(resources, experiments, plot_name, csv_path):
    """
    Runs every experiment in `experiments` (an iterable of
    ExperimentParameters) with do_experiment_repeated. Writes the results in
    CSV format to csv_path, one row per experiment:

        x_axis_label,x_axis_value,min,q25,median,q75,max

    The last five columns are quantiles of the measured test accuracies.

    Rows are first written to "{csv_path}.partial", which is renamed to
    csv_path only after every experiment has finished. This matters because
    plot_experiment_series skips experiments whose CSV file already exists.
    Writing csv_path directly would let an interrupted run leave a truncated
    CSV that later runs silently plot as if it were complete. A leftover
    .partial file from an interrupted run is simply overwritten next time.
    """
    print(f"Running experiments: {plot_name}.")
    partial_path = f"{csv_path}.partial"
    with open(partial_path, "w") as out:
        for parameters in experiments:
            print(f"{fmttime()} Trying {parameters.x_axis_value}.")
            test_accuracies = do_experiment_repeated(resources, parameters)
            accuracy_quantiles = numpy.quantile(
                a = test_accuracies,
                q = (0, 0.25, 0.5, 0.75, 1),
            )
            csv_row = (
                (parameters.x_axis_label, parameters.x_axis_value) +
                tuple(accuracy_quantiles))
            out.write(",".join(map(str, csv_row)))
            out.write("\n")
    # Publish the complete results in one step. os.replace overwrites any
    # existing file at csv_path.
    os.replace(partial_path, csv_path)

def add_plot_html(resources, plot_name):
    """
    Appends to the HTML report (resources.html_out) a heading for the plot
    called plot_name, followed by its png image and a link to its pdf
    version. Both files are referenced by paths relative to the report.
    """
    snippets = (
        f"<h2>{plot_name}</h2>\n",
        f'<p><img src="{plot_name}.png"></p>\n',
        f'<p><a href="{plot_name}.pdf">[as pdf]</a></p>\n',
    )
    for snippet in snippets:
        resources.html_out.write(snippet)

def add_experiment_plot(resources, axis_kwargs, plot_name, csv_path,
                        plot_nonprivate_line):
    """
    Reads experiment results from the CSV file at csv_path. Plots them to
    {plot_name}.png and {plot_name}.pdf in resources.out_dir, and adds the
    plot to the HTML report.

    Each CSV row holds, in order:
    - an x axis label;
    - an x axis value;
    - the min, 25%, median, 75% and max test accuracy quantiles.
    This is the format do_experiment_series writes.

    axis_kwargs are passed to Figure.add_subplot (e.g. title, xlabel,
    xscale). When plot_nonprivate_line is true, a dashed horizontal line
    marks resources.nonprivate_test_accuracy.
    """
    add_plot_html(resources, plot_name)
    with open(csv_path) as csv_file:
        rows = (line.split(",") for line in csv_file)
        columns = tuple(zip(*rows))
    x_labels = columns[0]
    # float() tolerates the trailing newline left on the last column.
    x, q0, q25, q50, q75, q100 = (
        numpy.array(tuple(map(float, column))) for column in columns[1:])

    figure = pyplot.figure(figsize = (6.4, 3))
    try:
        axes = figure.add_subplot(
            xticks = x,
            xticklabels = x_labels,
            ylabel = f"test accuracy (quantiles over {repetitions_per_experiment} "
                + "runs)",
            **axis_kwargs,
        )
        axes.plot(x, q100, ":k", label = "max")
        axes.errorbar(
            x, q50,
            yerr = (
                q50 - q25,
                q75 - q50,
            ),
            fmt = "k",
            label = "0.25/0.5/0.75 quantile",
        )
        axes.plot(x, q0, ":k", label = "min")
        if plot_nonprivate_line:
            axes.axhline(y = resources.nonprivate_test_accuracy, color = "k",
                         linestyle = "--", label = "without privacy")
        axes.legend()
        figure.tight_layout()
        figure.savefig(f"{resources.out_dir}/{plot_name}.png", dpi = plot_dpi)
        figure.savefig(f"{resources.out_dir}/{plot_name}.pdf")
    finally:
        # pyplot keeps a reference to every figure it creates until the
        # figure is explicitly closed. Without this, each plot would leak a
        # figure for the rest of the run.
        pyplot.close(figure)

def plot_experiment_series(resources, experiments, plot_name, axis_kwargs,
                           plot_nonprivate_line = True):
    """
    Produces the plot called plot_name in two steps:
    1. Runs `experiments` with do_experiment_series, unless
       {plot_name}.csv already exists in resources.out_dir (then the cached
       results are reused).
    2. Draws the results with add_experiment_plot.
    """
    csv_path = f"{resources.out_dir}/{plot_name}.csv"
    already_run = os.path.exists(csv_path)
    if not already_run:
        do_experiment_series(resources, experiments, plot_name, csv_path)
    add_experiment_plot(resources, axis_kwargs, plot_name, csv_path,
                        plot_nonprivate_line)

def add_nondifference_vary_epsilon_plots(resources):
    """
    Adds a plot of test accuracy against epsilon when labels are shared
    through an ordinary (non-difference) sketch with 500K buckets.
    """
    fixed_dimension = 500_000
    series = [
        ExperimentParameters(
            epsilon = eps,
            sketch_dimension = fixed_dimension,
            x_axis_value = eps,
            x_axis_label = eps,
            use_difference_sketch = False,
        )
        for eps in epsilons
    ]
    axes_setup = {
        "title":
            "Test accuracy vs. epsilon "
            "(non-difference sketch, dimension=500K)",
        "xlabel": "epsilon",
        "xscale": "log",
    }
    plot_experiment_series(
        resources = resources,
        experiments = series,
        plot_name = "nondifference_vary_epsilon",
        axis_kwargs = axes_setup,
    )

def add_vary_epsilon_plots(resources):
    """
    Adds a plot of test accuracy against epsilon, using the default
    difference sketch with 500K buckets.
    """
    fixed_dimension = 500_000
    series = [
        ExperimentParameters(
            epsilon = eps,
            sketch_dimension = fixed_dimension,
            x_axis_value = eps,
            x_axis_label = eps,
        )
        for eps in epsilons
    ]
    axes_setup = {
        "title": "Test accuracy vs. epsilon (dimension=500K)",
        "xlabel": "epsilon",
        "xscale": "log",
    }
    plot_experiment_series(
        resources = resources,
        experiments = series,
        plot_name = "vary_epsilon",
        axis_kwargs = axes_setup,
    )

def add_vary_sketch_dimension_plots(resources):
    """
    Adds a plot of test accuracy against sketch dimension (number of
    buckets), from 1 to 5M, at epsilon=1.
    """
    fixed_epsilon = 1.0
    dimensions = (1, 5, 50, 500, 5000, 50_000, 500_000, 5_000_000)
    series = [
        ExperimentParameters(
            epsilon = fixed_epsilon,
            sketch_dimension = dim,
            x_axis_value = dim,
            x_axis_label = int_to_short_string(dim),
        )
        for dim in dimensions
    ]
    axes_setup = {
        "title": "Test accuracy vs. sketch dimension (epsilon=1)",
        "xlabel": "sketch dimension",
        "xscale": "log",
    }
    plot_experiment_series(
        resources = resources,
        experiments = series,
        plot_name = "vary_sketch_dimension",
        axis_kwargs = axes_setup,
    )

def add_nonprivate_vary_num_examples_plots(resources):
    """
    Adds a plot of non-private test accuracy against the number of training
    rows used, as a baseline for the private experiments. The largest point
    always uses the whole training set.

    No "without privacy" reference line is drawn, because every point here
    is already non-private.
    """
    plot_name = "nonprivate_vary_num_examples"
    # Not used when epsilon is None (no sketch is built), but
    # ExperimentParameters requires a value.
    sketch_dimension = 500_000
    num_all_examples = len(resources.training_set)
    # Subset sizes to try, capped at the training set size. The final point
    # uses every row instead of a hard-coded dataset size.
    nums_examples = tuple(
        n for n in (100, 1000, 5000, 15000) if n < num_all_examples
    ) + (num_all_examples,)
    experiments = (
        ExperimentParameters(
            epsilon = None,
            sketch_dimension = sketch_dimension,
            num_rows_both_sides = num_examples,
            x_axis_value = num_examples,
            x_axis_label = int_to_short_string(num_examples),
        )
        for num_examples in nums_examples)
    title = "Test accuracy vs. # examples (no privacy)"
    plot_experiment_series(
        resources = resources,
        experiments = experiments,
        plot_name = plot_name,
        axis_kwargs = {
            "title": title,
            "xlabel": "number of training examples used",
            "xscale": "log",
        },
        plot_nonprivate_line = False,
    )

def add_features_only_plots(resources):
    """
    Adds a plot of test accuracy as extra people, present only in the
    features dataset D_R, are added.

    Half of the training set (rounded down) is on both sides of the join.
    For each fraction in features_only_fractions, that fraction of the whole
    training set (rounded up) is also placed on the features side only. Uses
    epsilon=1 and the default difference sketch with 500K buckets.
    """
    plot_name = "features_only"
    num_all_examples = len(resources.training_set)
    epsilon = 1
    sketch_dimension = 500_000
    # Number of examples that appear on both sides of the join.
    num_examples_both_sides = math.floor(0.5 * num_all_examples)
    # Number of extra examples that appear only on the features side, as a
    # fraction of the whole training set.
    features_only_fractions = (0, 0.1, 0.25, 0.5)

    experiments = []
    for features_only_fraction in features_only_fractions:
        num_rows_just_features = math.ceil(
            features_only_fraction * num_all_examples)
        # The x axis is the size of party R (features)'s dataset divided by
        # the size of the join, computed from the row counts actually used.
        # The join holds about half of the training set, so this is roughly
        # 1 + 2 * features_only_fraction.
        x_axis_value = (
            (num_examples_both_sides + num_rows_just_features)
            / num_examples_both_sides)
        experiments.append(ExperimentParameters(
            epsilon = epsilon,
            sketch_dimension = sketch_dimension,
            num_rows_both_sides = num_examples_both_sides,
            num_rows_just_features = num_rows_just_features,
            x_axis_value = x_axis_value,
            x_axis_label = f"{x_axis_value:.2f}",
        ))
    plot_experiment_series(
        resources = resources,
        experiments = experiments,
        plot_name = plot_name,
        axis_kwargs = {
            "title": "Test accuracy vs. # extra rows in $D_R$ (dimension=500K)",
            "xlabel": "$|D_R| / |D_R ⋈ D_S|$",
        },
    )

def add_labels_only_plots(resources, sketch_dimension):
    """
    Adds a plot of test accuracy as extra people, present only in the labels
    dataset D_S, are added, using a sketch with sketch_dimension buckets.

    Half of the training set (rounded down) is on both sides of the join.
    For each fraction in labels_only_fractions, that fraction of the whole
    training set (rounded up) is also placed on the labels side only. Uses
    epsilon=1 and the default difference sketch.
    """
    plot_name = f"labels_only_dim_{sketch_dimension}"
    num_all_examples = len(resources.training_set)
    epsilon = 1
    # Number of examples that appear on both sides of the join.
    num_examples_both_sides = math.floor(0.5 * num_all_examples)
    # Number of extra examples that appear only on the labels side, as a
    # fraction of the whole training set.
    labels_only_fractions = (0, 0.1, 0.25, 0.5)

    experiments = []
    for labels_only_fraction in labels_only_fractions:
        num_rows_just_labels = math.ceil(
            labels_only_fraction * num_all_examples)
        # The x axis is the size of party S (labels)'s dataset divided by the
        # size of the join, computed from the row counts actually used. The
        # join holds about half of the training set, so this is roughly
        # 1 + 2 * labels_only_fraction.
        x_axis_value = (
            (num_examples_both_sides + num_rows_just_labels)
            / num_examples_both_sides)
        experiments.append(ExperimentParameters(
            epsilon = epsilon,
            sketch_dimension = sketch_dimension,
            num_rows_both_sides = num_examples_both_sides,
            num_rows_just_features = 0,
            num_rows_just_labels = num_rows_just_labels,
            x_axis_value = x_axis_value,
            x_axis_label = f"{x_axis_value:.2f}",
        ))
    plot_experiment_series(
        resources = resources,
        experiments = experiments,
        plot_name = plot_name,
        axis_kwargs = {
            "title": "Test accuracy vs. # extra label rows (d = "
                + f"{sketch_dimension})",
            "xlabel": "$|D_S| / |D_R ⋈ D_S|$",
        },
    )

if __name__ == "__main__":
    # Choose the output directory. Either it is the one named on the command
    # line, which may already exist so that cached CSV results can be
    # replotted, or it is a fresh one whose name comes from fmttime().
    num_argv = len(sys.argv)
    if num_argv == 2:
        out_dir = sys.argv[1]
        os.makedirs(out_dir, exist_ok = True)
    elif num_argv == 1:
        out_dir = f"out/adult_logistic_regression_{fmttime()}"
        os.makedirs(out_dir)
    else:
        raise RuntimeError("Too many command-line arguments.")

    report_html_path = f"{out_dir}/index.html"
    print(f"Writing report to {report_html_path}.")

    with open(report_html_path, "w") as html_out:
        # Load the Adult training and test sets as (features, label) pairs.
        with open("datasets/adult/adult.data") as training_file:
            training_set = adult.read_as_features_labels(training_file)
        with open("datasets/adult/adult.test") as test_file:
            test_set = adult.read_as_features_labels(test_file)

        html_out.write(report_header)
        # Baseline: train on the full training set with no privacy at all.
        nonprivate_test_accuracy = do_nonprivate_experiment_once(
            training_set, test_set, num_rows_both_sides = None)
        html_out.write(
            f"<p>Nonprivate test accuracy: {nonprivate_test_accuracy}.</p>\n")

        resources = Resources(
            seed_sequence = numpy.random.SeedSequence(),
            training_set = training_set,
            test_set = test_set,
            nonprivate_test_accuracy = nonprivate_test_accuracy,
            html_out = html_out,
            out_dir = out_dir,
        )
        # Each call runs (or reloads from CSV) one experiment series and
        # appends its plot to the report, in this order.
        plot_steps = (
            add_nondifference_vary_epsilon_plots,
            add_vary_epsilon_plots,
            add_vary_sketch_dimension_plots,
            add_nonprivate_vary_num_examples_plots,
            add_features_only_plots,
        )
        for add_plots in plot_steps:
            add_plots(resources)
        for labels_only_dimension in (50, 500_000):
            add_labels_only_plots(resources,
                                  sketch_dimension = labels_only_dimension)
        html_out.write(report_footer)
    print(f"Done. See output at {report_html_path}.")
