"""
Score functions for abstaining classifiers
"""

from typing import Union, Dict
import numpy as np
from numpy.typing import ArrayLike

import comparecast as cc
from comparecast.scoring import ScoringRule, ZeroOneScore


def compute_scores(
        predictions: np.ndarray,
        abstentions: np.ndarray,
        labels: np.ndarray,
        scoring_rule: Union[str, ScoringRule] = ZeroOneScore(),
        compute_se: bool = False,
) -> Dict:
    """Summarize an abstaining classifier's performance.

    Three quantities are reported, keyed by name in the returned dict:

    - ``"selective_score"``: the scoring rule applied only to examples
      that were not abstained on;
    - ``"coverage"``: the non-abstention indicator itself;
    - ``"oracle_cf_score"``: the scoring rule applied to every example,
      regardless of abstention.

    Args:
        predictions: per-example predictions.
        abstentions: per-example abstention flags; their logical negation
            is used as a mask into ``predictions`` and ``labels``.
        labels: per-example labels.
        scoring_rule: passed through ``cc.get_scoring_rule`` to obtain the
            callable that scores ``(predictions, labels)``.
        compute_se: if true, each value is a ``(mean, standard error)``
            tuple, where the standard error is ``std() / sqrt(len)`` of
            the per-example values; otherwise each value is the mean only.

    Returns:
        A dict with the three keys listed above.
    """
    kept = np.logical_not(abstentions)
    rule = cc.get_scoring_rule(scoring_rule)

    # Per-example values for each reported quantity, in output order.
    per_example = {
        "selective_score": rule(predictions[kept], labels[kept]),
        "coverage": kept,
        "oracle_cf_score": rule(predictions, labels),
    }

    if not compute_se:
        return {name: values.mean() for name, values in per_example.items()}

    return {
        name: (values.mean(), values.std() / np.sqrt(len(values)))
        for name, values in per_example.items()
    }

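# NOTE: unused, incomplete snippet kept for reference
# (macro-averaged, one-vs-rest ROC AUC over per-class probability columns):
#           roc_auc_score(df_test["y"],
#                         df_test[[f"{clf_name}_prob{c}" for c in range(n_classes)]],
#                         average="macro",
#                         multi_class="ovr"))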


def compute_chow_score(
        predictions: ArrayLike,
        abstentions: ArrayLike,
        labels: ArrayLike,
        scoring_rule: Union[str, ScoringRule] = ZeroOneScore(),
        gamma: float = 0.1,
        reduction: str = "mean",
        **kwargs
) -> Union[float, np.ndarray]:
    """A generalized Chow's score for abstaining classifiers.

        chow_score((p, a), y) = scoring_rule(p[~a], y[~a]) + gamma * #(~a).

    The score is built per example: a non-abstained example contributes
    ``scoring_rule(p_i, y_i) + gamma`` and an abstained example contributes
    ``0``. Summing these per-example values gives exactly the formula above.

    Args:
        predictions: per-example predictions (any array-like).
        abstentions: per-example abstention flags (any array-like); a truthy
            entry means the classifier abstained on that example.
        labels: per-example labels (any array-like).
        scoring_rule: a callable taking ``(predictions, labels)``, or a
            string, which is resolved via ``cc.get_scoring_rule``.
        gamma: amount added to the score of every non-abstained example.
        reduction: one of ``"mean"``, ``"sum"``, or ``"none"``
            (case-insensitive). ``"mean"`` and ``"sum"`` reduce over all
            examples, abstained ones included; ``"none"`` returns the
            per-example array.
        **kwargs: accepted for call-site compatibility; not used.

    Returns:
        A scalar for ``"mean"``/``"sum"``, or an array with one entry per
        example for ``"none"``.

    Raises:
        ValueError: if ``reduction`` is not one of the recognized options.
    """
    # Inputs are declared ArrayLike, so coerce before boolean-mask indexing
    # (a plain list cannot be indexed with a boolean array).
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    non_absts = np.logical_not(np.asarray(abstentions))

    # Only strings need resolving; callables are used as given.
    if isinstance(scoring_rule, str):
        scoring_rule = cc.get_scoring_rule(scoring_rule)

    selective_score = np.asarray(
        scoring_rule(predictions[non_absts], labels[non_absts]),
        dtype=float,
    )

    # The selective scores cover only the non-abstained examples, so they
    # cannot be added elementwise to a full-length coverage vector. Scatter
    # them (plus the gamma bonus) into a full-length array instead, leaving
    # abstained examples at zero.
    score = np.zeros(non_absts.shape, dtype=float)
    score[non_absts] = selective_score + gamma

    reduction = reduction.lower()
    if reduction == "mean":
        return score.mean()
    elif reduction == "sum":
        return score.sum()
    elif reduction == "none":
        return score
    else:
        raise ValueError(f"unrecognized reduction method: {reduction}")

