# plot eta for article
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from active_ranking import config
from active_ranking.experiments.experiment import Experiment, models
from active_ranking.experiments.plotting import plot
from active_ranking.metrics import roc
from active_ranking.scenarios.complexity import problem_complexity
from active_ranking.scenarios.eta import plot_eta_in_one_d


def _row_of_axes(ncols, **subplot_kwargs):
    """Create a figure holding a single row of ``ncols`` subplots.

    ``squeeze=False`` keeps the result indexable even when ``ncols == 1``
    (matplotlib would otherwise return a bare Axes instead of an array).

    Returns
    -------
    (Figure, 1-D array of Axes)
    """
    fig, axes = plt.subplots(ncols=ncols, squeeze=False, **subplot_kwargs)
    return fig, axes[0]


def _load_experiments(name):
    """Build and ``load()`` one ``Experiment(model, 0, name)`` per entry of ``models``."""
    experiments = [Experiment(m, 0, name) for m in models]
    for exp in experiments:
        exp.load()
    return experiments


def _reference_curve(model, name):
    """Return ``1 - mean(inf_norm)`` per ``n_sample`` for ``model`` on ``name``.

    Values are ordered by increasing ``n_sample`` (groupby sorts its keys).
    """
    ref = Experiment(model, 0, name)
    ref.load()
    # Select the column before aggregating: averaging the whole table raises
    # on pandas >= 2.0 if it contains any non-numeric column.
    mean_inf_norm = ref._table.groupby("n_sample")["inf_norm"].mean()
    return 1 - mean_inf_norm.values


def _plot_dinf_panel(ax, experiments, with_ylabel, **plot_kwargs):
    """Draw the ``inf_norm`` curve of every experiment on ``ax``.

    Extra keyword arguments (e.g. ``x=`` / ``y=``) are forwarded unchanged to
    ``plot``. The y label is set only when ``with_ylabel`` is true.
    """
    # `plot` receives no axes argument, so make `ax` the current pyplot axes.
    plt.sca(ax)
    for exp in experiments:
        plot(exp, metric="inf_norm", kwargs_fill={"alpha": 0.2},
             kwargs_plot={"label": f"{exp.method.name}"}, **plot_kwargs)
    if with_ylabel:
        plt.ylabel(r"1 - $d_\infty$")


def _add_model_legend(fig, ax):
    """Add one figure-level legend built from the first ``len(models)`` handles of ``ax``."""
    handles, labels = ax.get_legend_handles_labels()
    n_models = len(models)
    fig.legend(handles[:n_models], labels[:n_models],
               loc='upper center', ncol=4, frameon=True,
               bbox_to_anchor=(0.5, 1.02))


def plot_figures(scenarios, path):
    """Render and save the article figures for a list of scenarios.

    Four figures are written, each with one column per scenario:

    * ``{path}eta_scenarios``: the eta function of each scenario, read from
      ``results/eta_{id}``;
    * ``{path}dinf``: the ``inf_norm`` curve of every model, y range [0.35, 1];
    * ``{path}dinf_ref``: the same curves, passing ``x=`` the reference curve
      of ``models[-1]`` to ``plot``;
    * ``{path}dinf_ref_ratio``: the same curves, passing ``y=`` the reference
      curve of ``models[1]`` to ``plot``.

    Parameters
    ----------
    scenarios : sequence
        Scenario ids, formatted into ``results/eta_{id}`` and
        ``scenario_{id}``.
    path : str
        Prefix prepended verbatim to every output filename.
    """
    n_scen = len(scenarios)
    figure_args = dict(figsize=(n_scen * 1.8, 2.5), dpi=600)
    names = [f"scenario_{eta}" for eta in scenarios]

    # Figure 1: eta function of each scenario on a shared y axis.
    fig, axes = _row_of_axes(n_scen, sharey=True, **figure_args)
    for i, (eta, ax) in enumerate(zip(scenarios, axes)):
        plt.sca(ax)
        plot_eta_in_one_d(pd.read_pickle(f"results/eta_{eta}").values)
        if i > 0:
            plt.ylabel(None)  # only the leftmost panel keeps its y label
        plt.xlabel(f"Scenario {i + 1}")
    plt.tight_layout()
    fig.savefig(f"{path}eta_scenarios")
    plt.close(fig)  # this function runs several times per script; free memory

    # Figure 2: inf_norm curves of every model, common y range.
    fig, axes = _row_of_axes(n_scen, sharey=True, **figure_args)
    for i, (name, ax) in enumerate(zip(names, axes)):
        _plot_dinf_panel(ax, _load_experiments(name), with_ylabel=i == 0)
        ax.set_ylim((0.35, 1))
    _add_model_legend(fig, axes[-1])
    fig.savefig(f"{path}dinf")
    plt.close(fig)

    # Figure 3: curves plotted against the reference curve of models[-1]
    # (presumably used by `plot` as the x axis; confirm in `plot`).
    fig, axes = _row_of_axes(n_scen, **figure_args)
    for i, (name, ax) in enumerate(zip(names, axes)):
        x_ref = _reference_curve(models[-1], name)
        _plot_dinf_panel(ax, _load_experiments(name), with_ylabel=i == 0,
                         x=x_ref)
    _add_model_legend(fig, axes[-1])
    fig.savefig(f"{path}dinf_ref")
    plt.close(fig)

    # Figure 4: curves relative to the reference curve of models[1].
    # NOTE(review): figure 3 uses models[-1] and this one models[1]; confirm
    # that the different reference models are intentional.
    fig, axes = _row_of_axes(n_scen, dpi=figure_args["dpi"],
                             figsize=(n_scen * 1.95, 2.5))
    for i, (name, ax) in enumerate(zip(names, axes)):
        y_ref = _reference_curve(models[1], name)
        _plot_dinf_panel(ax, _load_experiments(name), with_ylabel=i == 0,
                         y=y_ref)
    _add_model_legend(fig, axes[-1])
    fig.savefig(f"{path}dinf_ref_ratio")
    plt.close(fig)


# Grid figures: first for a hand-picked subset of scenarios (files prefixed
# "s_"), then for all eight scenarios (files prefixed "all_").
selected_eta = [2, 3, 4, 8]
for _scenario_ids, _out_prefix in (
        (selected_eta, "results/figures/article/s_"),
        (list(range(1, 9)), "results/figures/article/all_"),
):
    plot_figures(_scenario_ids, path=_out_prefix)

# --- Illustration of the ROC curve ---
# `levels_` are the values of the true piecewise-constant eta (one per
# equal-width bin of eta_function); `levels_learned_` are the values of the
# approximate score whose ROC curve is compared with the optimal one.
levels_, levels_learned_ = [0.1, 0.2, 0.6], [1, 2, 2]


def eta_function(x, levels):
    """Evaluate a step function with equal-width steps on [0, 1).

    The unit interval is split into ``len(levels)`` half-open bins
    ``[k / n, (k + 1) / n)``; every point of ``x`` inside bin ``k`` is mapped
    to ``levels[k]``. Points outside [0, 1), including 1 itself, stay 0.

    Parameters
    ----------
    x : numpy.ndarray
        Evaluation points.
    levels : sequence
        Value taken on each bin, in order.

    Returns
    -------
    numpy.ndarray
        Same shape *and dtype* as ``x`` (levels are cast to that dtype).
    """
    n_bins = len(levels)
    out = np.zeros_like(x)
    for k in range(n_bins):
        lower, upper = k / n_bins, (k + 1) / n_bins
        in_bin = (x >= lower) & (x < upper)
        out[in_bin] = levels[k]
    return out


# Labels are drawn as Y = 1 with probability eta_true(x). A dedicated seeded
# generator makes the published figure identical between runs of the script
# (the original used the unseeded global NumPy RNG).
rng = np.random.default_rng(0)
# eta_function uses half-open bins [k/n, (k+1)/n), so x = 1 would fall outside
# every bin; the grid therefore stops just below 1.
x_test = np.linspace(0, 0.99, num=10000)
eta_true = eta_function(x_test, levels_)
y_test = eta_true > rng.uniform(0, 1, size=len(x_test))
# Coarser score (levels_learned_) whose ROC is contrasted with ROC*.
eta_appx = eta_function(x_test, levels_learned_)

mini_size = dict(figsize=(2.5, 2.5), dpi=300)
plt.figure(**mini_size)
roc_true = roc.ROC(eta_true, y_test)
roc_ = roc.ROC(eta_appx, y_test)
roc_true.plot_roc_curve(label="ROC$^*$")
roc_.plot_roc_curve(label="ROC$(\\tilde s)$")
plt.legend()
roc_true.plot_back_ground()
plt.tight_layout()
plt.savefig("results/figures/article/illustrate_roc")
plt.close()

# --- Illustration of problem complexity (scenario 1) ---
# Left axis: eta of scenario 1. Right (twin) axis: the curve returned by
# problem_complexity, drawn in config.colors[2] with a matching axis label.
eta = pd.read_pickle("results/eta_1")

plt.figure(dpi=mini_size["dpi"], figsize=(3, 2.5))
plot_eta_in_one_d(eta)
c = problem_complexity(eta.values)
plt.gca().set_ylim((0, 1))

ax2 = plt.gca().twinx()
plt.sca(ax2)
plot_eta_in_one_d(c, order=False, color=config.colors[2])
ax2.grid(False)  # avoid a second grid overlaying the left axis grid
ax2.set_ylabel("Complexity", color=config.colors[2])

plt.tight_layout()
plt.savefig("results/figures/article/levels_eta_1")
