from typing import Mapping

import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from tabular_mvdrl.utils.discrete_distributions import DiscreteDistribution

# Opacity bounds for the scatter markers in plot_bivariate_empirical_distributions_kde.
# MIN_ALPHA: alpha given to the largest atom when every |probability| is below
#   this value, so that very small masses remain visible.
MIN_ALPHA = 0.2
# MAX_ALPHA: alpha given to the largest atom when some |probability| exceeds 1
#   and the weights must be rescaled into the valid [0, 1] alpha range.
MAX_ALPHA = 0.8


def plot_univariate_empirical_distributions(data: Mapping[str, DiscreteDistribution]):
    """Overlay weighted histograms of several univariate discrete distributions.

    Each entry is drawn as a semi-transparent histogram of ``dist.locs``
    weighted by ``dist.probs``. Because ``density=True`` is used, matplotlib
    normalises each histogram by its total weight and by the bin width. The
    y-axis therefore shows a probability *density* (mass / bin width), not a
    probability mass.

    Note: no ``bins`` argument is passed, so each histogram uses matplotlib's
    default binning computed from its own ``locs``. Bin edges may therefore
    differ between entries.

    Args:
        data: Mapping from legend label to distribution.

    Returns:
        The matplotlib Figure containing the plot.
    """
    fig, ax = plt.subplots()
    for tag, dist in data.items():
        ax.hist(
            dist.locs,
            density=True,
            weights=dist.probs,
            ec="black",
            label=tag,
            alpha=0.5,
        )
    ax.set_xlabel("Return")
    # density=True divides by total weight and bin width, so heights are densities.
    ax.set_ylabel("Probability density")
    # Avoid matplotlib's "No artists with labels" warning when data is empty.
    if data:
        ax.legend()
    fig.tight_layout()
    return fig


def _scatter_alphas(magnitudes: jnp.ndarray, max_prob: float) -> list[float]:
    """Return one marker alpha in [0, 1] per entry of ``magnitudes``.

    ``magnitudes`` are non-negative probability magnitudes. ``max_prob`` is the
    largest magnitude across *all* scattered atoms, so that positive and
    negative marker groups share one scale:

    * ``max_prob > 1``: rescale so the largest atom gets ``MAX_ALPHA``;
    * ``0 < max_prob < MIN_ALPHA``: rescale so the largest atom gets
      ``MIN_ALPHA`` (keeps very small masses visible);
    * otherwise, use the magnitudes directly as alphas. This branch also covers
      ``max_prob == 0``, which avoids a division by zero.

    Each value is finally capped at 1.0.
    """
    if max_prob > 1.0:
        scaled = MAX_ALPHA * magnitudes / max_prob
    elif 0.0 < max_prob < MIN_ALPHA:
        scaled = magnitudes * MIN_ALPHA / max_prob
    else:
        scaled = magnitudes
    return [min(w, 1.0) for w in scaled.tolist()]


def plot_bivariate_empirical_distributions_kde(
    data: Mapping[str, DiscreteDistribution], scatter_tag: str | None = None
):
    """Plot a filled, weighted KDE of the ``"MC"`` distribution, optionally overlaying atoms.

    Side effect: calls ``sns.set_theme("talk")`` and ``sns.set_style("whitegrid")``.
    These change seaborn's global plotting defaults for subsequent plots.

    Args:
        data: Mapping from method tag to distribution. Must contain the key
            ``"MC"``. Its ``locs`` are indexed as ``[:, 0]`` and ``[:, 1]``, so
            they must be 2-D with at least two columns.
        scatter_tag: Optional key into ``data`` whose atoms are scattered on top
            of the KDE in red:
            - atoms with probability >= 0 are drawn as circles;
            - atoms with probability < -1e-4 are drawn as crosses;
            - atoms with probability in [-1e-4, 0) are not drawn.
            Marker opacity encodes |probability| (see ``_scatter_alphas``).

    Returns:
        The matplotlib Figure containing the plot.

    Raises:
        KeyError: if ``"MC"`` or ``scatter_tag`` is missing from ``data``.
    """
    sns.set_theme("talk")
    sns.set_style("whitegrid")
    mc_dist = data["MC"]
    tags = ["MC"] * len(mc_dist.probs)
    df = pd.DataFrame(
        {
            "Cumulant 1": mc_dist.locs[:, 0],
            "Cumulant 2": mc_dist.locs[:, 1],
            "Probability": mc_dist.probs,
            "Method": tags,
        }
    )

    fig, ax = plt.subplots()
    sns.kdeplot(
        df,
        x="Cumulant 1",
        y="Cumulant 2",
        weights="Probability",
        fill=True,
        alpha=0.5,
        ax=ax,
        warn_singular=False,
        label="MC",
    )

    if scatter_tag is not None:
        scatter_data = data[scatter_tag]
        scatter_locs = scatter_data.locs
        scatter_probs = scatter_data.probs
        # One normaliser shared by both marker groups. jnp.max raises on an
        # empty array, so fall back to 0.0 when there is nothing to scatter.
        if scatter_probs.size > 0:
            max_prob = jnp.max(jnp.abs(scatter_probs)).item()
        else:
            max_prob = 0.0
        positive_prob_mask = scatter_probs >= 0
        # Tiny negative masses (presumably numerical noise) are deliberately skipped.
        negative_prob_mask = scatter_probs < -1e-4
        positive_locs = scatter_locs[positive_prob_mask]
        positive_probs = scatter_probs[positive_prob_mask]
        negative_locs = scatter_locs[negative_prob_mask]
        negative_probs = -scatter_probs[negative_prob_mask]

        positive_alphas = _scatter_alphas(positive_probs, max_prob)
        ax.scatter(
            positive_locs[:, 0],
            positive_locs[:, 1],
            color=[(1.0, 0.0, 0.0, a) for a in positive_alphas],
            marker="o",
            label=scatter_tag,
        )
        if negative_probs.size > 0:
            negative_alphas = _scatter_alphas(negative_probs, max_prob)
            # Same colour as the positive atoms; distinguished by marker shape.
            ax.scatter(
                negative_locs[:, 0],
                negative_locs[:, 1],
                color=[(1.0, 0.0, 0.0, a) for a in negative_alphas],
                marker="x",
                label=f"{scatter_tag} (-)",
            )
        ax.legend()
    fig.tight_layout()
    return fig


def plot_bivariate_empirical_distributions(
    data: Mapping[str, DiscreteDistribution], scatter_tag: str | None = None
):
    """Scatter the atoms of each bivariate distribution in ``data``.

    Colours and opacity:
    - Each tag gets one colour sampled evenly from the ``viridis`` colormap,
      in the mapping's iteration order.
    - Each atom's opacity is proportional to its probability. The most
      probable atom of that tag gets alpha 0.7.
    - Alphas are clipped to [0, 1]. Negative probabilities are drawn fully
      transparent, and a tag whose largest probability is <= 0 is drawn fully
      transparent.

    Args:
        data: Mapping from tag (used as the legend label) to distribution.
            ``locs`` columns 0 and 1 are used as x and y. Assumes shape
            (n_atoms, 2) — TODO confirm.
        scatter_tag: Unused. Accepted so the signature matches
            ``plot_bivariate_empirical_distributions_kde``.

    Returns:
        The matplotlib Figure containing the plot.
    """
    cmap = plt.get_cmap("viridis")
    denom = max(len(data) - 1, 1)
    max_alpha = 0.7
    fig, ax = plt.subplots()
    for i, (tag, dist) in enumerate(data.items()):
        probs = dist.probs
        if probs.size == 0:
            continue  # jnp.max raises on empty arrays; nothing to draw anyway
        peak = jnp.max(probs).item()
        # Guard the normalisation: a non-positive peak would divide by zero or
        # flip signs. Negative masses would also give alphas outside [0, 1],
        # which matplotlib rejects.
        weights = probs / peak if peak > 0 else jnp.zeros_like(probs)
        alphas = jnp.clip(max_alpha * weights, 0.0, 1.0)
        r, g, b, _ = cmap(i / denom)
        locs = dist.locs
        # One vectorised scatter per tag (per-point RGBA) instead of one
        # artist per atom.
        ax.scatter(
            locs[:, 0],
            locs[:, 1],
            color=[(r, g, b, a) for a in alphas.tolist()],
            label=tag,
        )
    if data:
        ax.legend()
    fig.tight_layout()
    return fig
