"""
Create plots illustrating sortability for increasing graph size (typically convergence to high values) for different weight distribution parameters.
"""
import CDExperimentSuite_DEV as CDES
from auxiliary.vsb_emergence import vsb_investigation, approx_b
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from copy import deepcopy

# Global matplotlib styling for every figure produced by this script:
# text is rendered through LaTeX (with the bm and amsfonts packages loaded),
# and font, line and marker sizes are set explicitly.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{bm}\usepackage{amsfonts}",
        "axes.labelsize": 25,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
        "legend.fontsize": 18,
        "legend.title_fontsize": 16,
        "lines.linewidth": 3,
        "lines.markersize": 12,
    }
)


def graph_size_exp(opt, vsb_functions):
    """Run the graph-size experiment for a single sortability measure.

    The experiment name, ``edges`` and ``edge_weights`` entries of the options
    are overridden here; all other entries of ``opt`` are passed through to
    ``CDES.utils.Options`` unchanged. The caller's ``opt`` mapping is not
    modified.

    Parameters
    ----------
    opt : dict
        Base experiment options (keyword arguments for ``CDES.utils.Options``).
    vsb_functions : dict
        Mapping with exactly one entry; its key is used as the y-axis column
        name and in the experiment name, and the mapping is forwarded to
        ``vsb_investigation``.
    """
    assert len(vsb_functions) == 1, "expected exactly one sortability function"
    vname = next(iter(vsb_functions))

    # One weight pair per value of k: the first entry is fixed at ``a`` and the
    # second is the first element of what approx_b returns for that k.
    a = 0.1
    w_ranges = [(a, approx_b(a=a, k=k)[0]) for k in np.arange(-1.0, 3, 0.5)]

    # Merge the overrides into a fresh dict so the caller's options stay intact.
    opt = CDES.utils.Options(
        **{
            **opt,
            "exp_name": f"{vname}_convergence_size",
            "edges": [2],
            "edge_weights": w_ranges,
        }
    )

    def plot_line_style(df):
        """Draw one line per weight setting of ``vname`` against ``d``."""
        # Legend column combining the weight pair with the "Eln|w|" column.
        w_low = df.w_low.astype(str)
        w_high = df.w_high.astype(str)
        df["legend"] = (
            r"$\pm$("
            + w_low
            + ","
            + w_high
            + "); "
            + r"$\mathrm{\mathbb{E}}[\ln|V|]$="
            + df["Eln|w|"].astype(str)
        )
        buffer = 3  # buffer to get colors centered around 0
        cmap = plt.get_cmap("coolwarm", 8 + buffer)
        # Skip the first ``buffer`` colors of the discretised colormap.
        rgb_colors = [cmap(i)[:3] for i in range(buffer, cmap.N)]
        palette = sns.color_palette(rgb_colors)
        # start figure
        _, ax = plt.subplots(figsize=(10, 4))
        sns.lineplot(data=df, x="d", y=vname, hue="legend", palette=palette, ax=ax)
        plt.ylim(0.45, 1.0)
        plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
        # Show the legend entries in reverse plotting order.
        handles, labels = ax.get_legend_handles_labels()
        plt.legend(
            handles[::-1],
            labels[::-1],
            title="Weights",
            bbox_to_anchor=(1.02, 1),
            fontsize=14,
        )
        plt.xlabel("Nodes")

    vsb_investigation(
        opt, vsb_functions=vsb_functions, plot_fun=plot_line_style, verbose=True
    )


def density_exp(opt, vsb_functions):
    """Run the density experiment for a single sortability measure.

    The experiment name, ``edges`` and ``edge_weights`` entries of the options
    are overridden here; all other entries of ``opt`` are passed through to
    ``CDES.utils.Options`` unchanged. The caller's ``opt`` mapping is not
    modified.

    Parameters
    ----------
    opt : dict
        Base experiment options (keyword arguments for ``CDES.utils.Options``).
    vsb_functions : dict
        Mapping with exactly one entry; its key is used as the y-axis column
        name and in the experiment name, and the mapping is forwarded to
        ``vsb_investigation``.
    """
    assert len(vsb_functions) == 1, "expected exactly one sortability function"
    vname = next(iter(vsb_functions))

    # Merge the overrides into a fresh dict so the caller's options stay intact.
    opt = CDES.utils.Options(
        **{
            **opt,
            "exp_name": f"{vname}_convergence_density",
            "edges": [0.1, 0.4, 4, 8, 16],
            "edge_weights": [(0.1, 0.3)],
        }
    )

    def plot_line_style(df):
        """Draw ``vname`` against ``d`` with one line per value of column ``x``."""
        _, ax = plt.subplots(figsize=(8.5, 7))
        sns.lineplot(
            data=df,
            x="d",
            y=vname,
            hue="x",
            # pyplot-level accessor, matching graph_size_exp; cm.get_cmap is
            # deprecated/removed in recent matplotlib releases.
            palette=plt.get_cmap("viridis", 8),
            ax=ax,
        )
        # text.usetex is enabled for this script, so the Greek letter must be
        # in math mode to be typeset by LaTeX.
        plt.legend(title=r"$\lambda$")
        plt.xlabel("Nodes")

    vsb_investigation(opt, vsb_functions=vsb_functions, plot_fun=plot_line_style)


if __name__ == "__main__":
    # Base configuration shared by all runs; the graph type and the output
    # directory suffix are filled in per run in the loop below.
    opt = {
        "overwrite": False,
        "base_dir": "src/results/Convergence/50",
        "exp_name": "_raw",
        # ---
        "MEC": False,
        "thres": 0,
        "thres_type": "standard",
        "vsb_function": CDES.utils.var_sortability,
        "R2sb_function": CDES.utils.r2_sortability,
        "CEVsb_function": CDES.utils.cev_sortability,
        # ---
        "n_repetitions": 10,
        "graphs": None,
        "edge_types": ["fixed"],
        "noise_distributions": [
            CDES.utils.NoiseDistribution("gauss", "uniform", (0.5, 2.0)),
        ],
        "scaler": CDES.Scalers.Identity(),
        "n_nodes": list(np.arange(5, 31, 3)),
        "n_obs": [1000],
    }

    # Sortability measures to plot, keyed by the label used in the figures.
    vsb_functions = {r"$R^2$-sortability": CDES.utils.r2_sortability}

    # One independent copy of the options per graph type, written to its own
    # "<base_dir>_<graph type>" directory; one experiment per measure.
    for graph_type in ("ER", "SF"):
        opt_g = deepcopy(opt)
        opt_g["base_dir"] = f"{opt_g['base_dir']}_{graph_type}"
        opt_g["graphs"] = [graph_type]
        for label, measure in vsb_functions.items():
            graph_size_exp(opt_g, {label: measure})
