import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from active_ranking import config
from active_ranking.metrics.roc import ROC


def steps(array):
    """Return a NumPy array in which every item of *array* appears twice in a row.

    For example ``[1, 2]`` becomes ``array([1, 1, 2, 2])``. Items are taken by
    plain iteration, so for a 2-D array each row is duplicated.
    """
    doubled = []
    for value in array:
        doubled.extend((value, value))
    return np.array(doubled)


def _shade_bounds(exp, metric, method):
    """Return the ``(lower, upper)`` band around ``exp.mean_n_sample[metric]``.

    Raises:
        ValueError: if *method* is not "std", "95" or "minmax".
    """
    mean = exp.mean_n_sample[metric]
    if method == "std":
        spread = exp.std_n_sample[metric]
    elif method in ("95", "minmax"):
        # NOTE(review): "minmax" reuses the 95% quantile exactly like "95";
        # a true min/max band would need per-run extrema — confirm intent.
        spread = exp.quantile_95[metric]
    else:
        raise ValueError(f"Method {method} is not known")
    return mean - spread, mean + spread


def plot(exp, metric="one_norm",
         kwargs_fill=None, kwargs_plot=None,
         method="std", plot_shades=True, plot_stopping_time=False, x="sample",
         y=None):
    """Plot ``1 - mean`` of *metric* for an experiment, optionally with a band.

    Args:
        exp: experiment exposing ``mean_n_sample``, ``std_n_sample`` and
            ``quantile_95``, each indexable by *metric*; ``mean_n_sample``
            also needs an ``.index`` when ``x="sample"``.
        metric: column of the experiment statistics to plot.
        kwargs_fill: extra keyword arguments for ``plt.fill_between``.
        kwargs_plot: extra keyword arguments for ``plt.plot``.
        method: band half-width: "std" (``std_n_sample``), "95" or "minmax"
            (both currently ``quantile_95``).
        plot_shades: whether to draw the band.
        plot_stopping_time: if true, draw the raw mean as a solid line up to
            the first point below ``config.epsilon`` and dashed afterwards.
        x: "sample" to use ``exp.mean_n_sample.index``, otherwise the x values.
        y: optional divisor for the plotted ``1 - ...`` values (no scaling
            when None). Not applied in the stopping-time branch.

    Raises:
        ValueError: if *method* is unknown.
    """
    kwargs_fill = {} if kwargs_fill is None else kwargs_fill
    kwargs_plot = {} if kwargs_plot is None else kwargs_plot

    lower, upper = _shade_bounds(exp, metric, method)
    mean = exp.mean_n_sample[metric]

    # isinstance guard: comparing an array-valued x to a string yields an
    # array, whose truth value is ambiguous.
    if isinstance(x, str) and x == "sample":
        x_plot = exp.mean_n_sample.index
    else:
        x_plot = x
    if y is None:
        # A scalar divisor; np.ones_like("sample") produced a string array
        # that made every division below raise TypeError.
        y = 1.0

    if plot_shades:
        plt.fill_between(
            x_plot,
            (1 - lower) / y,
            (1 - upper) / y,
            **kwargs_fill
        )

    if not plot_stopping_time:
        plt.plot(x_plot, (1 - mean) / y, **kwargs_plot)
        return

    # NOTE(review): this branch draws the raw metric, not ``(1 - metric) / y``
    # like the branch above and the band — confirm which is intended.
    below = np.flatnonzero(np.asarray(mean) < config.epsilon)
    if below.size == 0:
        # The metric never drops below epsilon: no stopping time to mark.
        plt.plot(x_plot, mean, **kwargs_plot)
        return

    # Keep at least one point in the solid segment.
    selection = max(below[0], 1)
    # Dashed part after the stopping time (unlabelled so the legend shows
    # only the solid part), then solid part before it in the same colour.
    dashed = plt.plot(
        x_plot[selection - 1:],
        mean.iloc[selection - 1:],
        **{**kwargs_plot, "ls": "--", "label": None})
    plt.plot(
        x_plot[:selection],
        mean.iloc[:selection],
        **{**kwargs_plot, "ls": "-", "color": dashed[0].get_color()})


def plot_lines(exp, metric="one_norm", kwargs_plot=None, kwargs_fill=None):
    """Plot the mean (± one std) ``n_sample`` at which runs reach each level.

    Levels are ``1 - metric`` rounded to two decimals. For every run (``id``)
    the smallest ``n_sample`` reaching each level is kept; the curve is the
    mean over runs and the band is ± one standard deviation.

    Args:
        exp: experiment whose ``_table`` DataFrame has ``id``, ``n_sample``
            and *metric* columns.
        metric: column turned into levels.
        kwargs_plot: extra keyword arguments for ``plt.plot``.
        kwargs_fill: extra keyword arguments for ``plt.fill_between``.
    """
    kwargs_fill = {} if kwargs_fill is None else kwargs_fill
    kwargs_plot = {} if kwargs_plot is None else kwargs_plot
    df: pd.DataFrame = exp._table.copy()
    df[metric] = 1 - df[metric].round(2)
    # Sorting first makes keep="first" retain the smallest n_sample for each
    # (run, level) pair.
    df = df.sort_values("n_sample").drop_duplicates(["id", metric],
                                                    keep="first")
    # One row per run, one column per level (ascending), cells hold n_sample.
    df_tra: pd.DataFrame = df[["id", "n_sample", metric]].set_index(
        ["id", metric]
    ).unstack(level=-1)
    df_tra.columns = df_tra.columns.to_frame()[metric].values
    # A level a run never hit exactly gets the n_sample of the next higher
    # level it did hit. ``fillna(method=...)`` is deprecated since pandas 2.1;
    # ``bfill`` is the documented equivalent.
    df_tra = df_tra.bfill(axis=1)

    # Keep levels observed for more than 90% of the runs.
    df_tra = df_tra.loc[:, df_tra.isna().mean() < 0.1]
    # Drop levels at or above 1 - epsilon.
    df_tra = df_tra.loc[:, df_tra.columns.to_numpy() < 1 - config.epsilon]
    std = df_tra.std()
    m = df_tra.mean()

    plt.plot(m, **kwargs_plot)
    plt.fill_between(m.index, m - std, m + std, **kwargs_fill)


def plot_true_roc_curves(
        y_true, y_prediction, y_roc_true
):
    """Draw the empirical and the true ROC curve on a new 300-dpi figure.

    Args:
        y_true: ground-truth labels shared by both curves.
        y_prediction: scores defining the empirical ROC curve.
        y_roc_true: scores defining the true ROC curve.
    """
    plt.figure(dpi=300)
    plt.grid(True)
    line_style = {"marker": ".", "lw": 1}

    empirical_roc = ROC(y_true=y_true, y_prediction=y_prediction)
    empirical_label = (
        f"Empirical ROC curve (AUC = {np.round(empirical_roc.auc, 2)})"
    )
    empirical_roc.plot_roc_curve(c=config.colors[1], label=empirical_label,
                                 **line_style)

    true_roc = ROC(y_true=y_true, y_prediction=y_roc_true)
    true_label = f"True ROC curve (AUC = {np.round(true_roc.auc, 2)})"
    # Small vertical offset, presumably so the true curve stays visible
    # where it coincides with the empirical one.
    shifted_beta = true_roc.c_beta + 0.003
    plt.plot(true_roc.c_alpha, shifted_beta, c=config.colors[0],
             label=true_label, **line_style)

    empirical_roc.plot_back_ground()
    plt.legend(loc=4)  # 4 == "lower right"


def plot_d1_norm(learner):
    """Open a new figure plotting ``learner.norm_one`` against ``learner.n_sample``."""
    plt.figure()
    axis = plt.gca()
    axis.plot(learner.n_sample, learner.norm_one,
              label="$d_1$", color=config.colors[0])
    axis.legend(loc=1)  # 1 == "upper right"


def plot_d_inf_norm(learner):
    """Open a new figure plotting ``learner.norm_infinity`` against
    ``learner.n_sample``, with a horizontal line at ``learner.epsilon``.
    """
    plt.figure()
    ax = plt.gca()
    # Raw strings: "\i" is an invalid escape sequence (SyntaxWarning since
    # Python 3.12). The resulting label text is unchanged.
    ax.plot(learner.n_sample,
            learner.norm_infinity, label=r"$d_\infty$", color=config.colors[0])
    ax.axhline(learner.epsilon, label=r"$\varepsilon$", color=config.colors[1])
    plt.legend(loc=1)


def plot_swarm_sampling(learner, animate=False):
    """Swarm-plot the partition cell sampled at each epoch.

    Cells are ordered by their position in
    ``learner.ret_p_cells["ranked_estimates"]``.

    Args:
        learner: object exposing ``partition.loc_step`` (epoch per sample),
            ``partition.loc`` (cell per sample) and ``ret_p_cells``.
        animate: if true, build an animation that reveals the epochs one by
            one instead of a static plot.

    Returns:
        A ``FuncAnimation`` when *animate* is true, otherwise ``None``.
    """
    import seaborn as sns
    from matplotlib.animation import FuncAnimation

    # A single figure for both modes (the animated mode used to open a
    # second one and leave the first empty).
    fig = plt.figure()
    ax = fig.gca()

    data_plot = pd.DataFrame()
    data_plot["epoch"] = learner.partition.loc_step
    data_plot["cell"] = learner.partition.loc
    ret = pd.Series(learner.ret_p_cells["ranked_estimates"])
    ret.name = "ranked_estimates"
    ret = pd.DataFrame(ret)
    ret["order"] = range(len(ret))
    data_plot = pd.merge(data_plot, ret, right_index=True, left_on="cell")
    data_plot = data_plot.sort_values(["order", "epoch"])

    if not animate:
        sns.swarmplot(data_plot, x="cell", hue="epoch", y='epoch',
                      palette="rocket", ax=ax)
        ax.grid(True)
        fig.tight_layout()
        return None

    def _draw_frame(frame):
        # Redraw from scratch: each frame shows every epoch <= frame, and
        # drawing over the previous frame would duplicate points.
        ax.clear()
        sns.swarmplot(data_plot[data_plot["epoch"] <= frame], x="cell",
                      hue="epoch", y='epoch', palette="rocket", ax=ax)
        ax.legend([], [], frameon=False)

    # An int ``frames`` means range(frames); +1 so the last epoch is shown.
    n_frames = int(data_plot["epoch"].max()) + 1
    return FuncAnimation(fig, _draw_frame, frames=n_frames, interval=300)


def plot_roc_epochs(learner):
    """Overlay the test-set ROC curve of every stored epoch on a new figure.

    Args:
        learner: object exposing ``predictions`` (a mapping from epoch to
            test-set scores) and ``y_test``.

    Curves are coloured by ``config.cmap`` according to the epoch's position
    between 0 and the largest epoch. The ROC background is drawn from the
    last curve. Nothing is drawn when there are no predictions.
    """
    plt.figure()
    epochs = list(learner.predictions.keys())
    if not epochs:
        # Without any epoch there is no ROC object to draw the background.
        return
    max_epoch = max(epochs)
    # Thinner lines as the number of overlaid curves grows.
    line_width = 0.99 ** len(epochs)

    for epoch in epochs:
        # Normalise by the largest epoch so colours span [0, 1]; dividing by
        # max - 1 overshot 1 and divided by zero when the largest key was 1.
        colour = config.cmap(epoch / max_epoch if max_epoch else 0.0)
        roc = ROC(y_prediction=learner.predictions[epoch],
                  y_true=learner.y_test)
        # A tiny per-epoch vertical offset keeps identical curves from
        # hiding one another.
        plt.plot(
            roc.c_alpha,
            roc.c_beta + 0.0001 * epoch,
            label=None,
            marker=".", color=colour, alpha=0.8, lw=line_width)
    roc.plot_back_ground()
