from functools import partial

import numpy as np
from scipy import stats
from sklearn.metrics import accuracy_score, average_precision_score, log_loss, dcg_score, ndcg_score, roc_auc_score
from sklearn.mixture import GaussianMixture

from utils import ranking_accuracy, hitrate_at_k #precision_at_k, recall_at_k

# Ranking metrics selectable by name. RankingEvaluator.evaluate looks a name up
# here and calls the entry as fn(y_true, y_score).
METRIC_DICT = {
    "ndcg": ndcg_score,
    "dcg": dcg_score,
    # Both correlation entries negate their first argument before correlating.
    # NOTE(review): presumably the targets and the scores are ordered in
    # opposite directions (lower-is-better vs. higher-is-better) -- confirm.
    "spearman": lambda truth, score: stats.spearmanr(-truth, score).statistic,
    "kendall_tau": lambda truth, score: stats.kendalltau(-truth, score).statistic,
    "rank_acc": ranking_accuracy,
}

# Register hitrate@1 .. hitrate@10. The cutoff is bound as the first positional
# argument, so each entry ends up being called as hitrate_at_k(k, y_true, y_score).
# Per the original author's note, precision@k and recall@k coincide under the
# definition used here, so only a single "hit rate" family is registered.
for k in range(1, 11):
    METRIC_DICT.update({f"hitrate@{k}": partial(hitrate_at_k, k)})

# Metrics that RankingEvaluator.evaluate_predictions calls on thresholded
# (hard 0/1) predictions: fn(y_true, y_pred).
PRED_METRIC_DICT = {"accuracy": accuracy_score}

# Metrics that RankingEvaluator.evaluate_predictions calls on the raw scores:
# fn(y_true, y_score).
PROB_METRIC_DICT = {"auroc": roc_auc_score, "log_loss": log_loss}

# Deprecated: only consulted by RankingEvaluator.pseudo_multiclass_evaluation,
# which calls fn(pseudo_labels, y_score).
PSEUDO_BINARY_METRICS = {"ap": average_precision_score, "auc": roc_auc_score}

def do_counterfactual_evaluations(meta_learner, evaluator, test_df, plan_df):
    """Evaluate the meta-learner's predictions once per distinct value of ``t``.

    For every distinct value ``tmt`` in ``test_df["t"]`` the column
    ``test_df[f"d_{tmt}"]`` is used as the target and
    ``meta_learner.predict_proba(test_df, plan_df, t=tmt)`` as the score; both
    are handed to ``evaluator.evaluate_predictions``.

    NOTE(review): ``d_{tmt}`` is presumably the (counterfactual) outcome under
    treatment ``tmt`` -- confirm against the code that builds ``test_df``.

    Args:
        meta_learner: Object exposing ``predict_proba(test_df, plan_df, t=...)``.
        evaluator: Object exposing ``evaluate_predictions(y_true, y_score)``
            that returns a mapping of metric name to value.
        test_df: Table with a ``"t"`` column and one ``"d_<value>"`` column
            per distinct value of ``t``.
        plan_df: Passed through to ``predict_proba`` untouched.

    Returns:
        dict: Flat mapping ``"cf_<tmt>_<metric>" -> value``.
    """
    flattened = {}
    for treatment in test_df["t"].unique():
        targets = test_df[f"d_{treatment}"]
        scores = meta_learner.predict_proba(test_df, plan_df, t=treatment)
        per_treatment = evaluator.evaluate_predictions(targets, scores)
        # Prefix every metric name so results from different treatments can
        # live side by side in one flat dict.
        flattened.update(
            {f"cf_{treatment}_{name}": per_treatment[name] for name in per_treatment}
        )
    return flattened


class RankingEvaluator:
    """Compute named metrics from the module-level metric registries.

    Metric names are resolved against ``METRIC_DICT`` (ranking metrics),
    ``PRED_METRIC_DICT`` / ``PROB_METRIC_DICT`` (prediction metrics) and
    ``PSEUDO_BINARY_METRICS`` (deprecated pseudo-label metrics).
    """

    def __init__(self, metrics, prediction_metrics, binary_metrics=None):
        """Store the metric names to compute.

        Args:
            metrics: Iterable of keys of ``METRIC_DICT`` used by ``evaluate``.
            prediction_metrics: Iterable of keys of ``PRED_METRIC_DICT`` or
                ``PROB_METRIC_DICT`` used by ``evaluate_predictions``.
            binary_metrics: Optional iterable of keys of
                ``PSEUDO_BINARY_METRICS``; only needed by the deprecated
                ``pseudo_multiclass_evaluation``.
        """
        self.metrics = metrics
        self.prediction_metrics = prediction_metrics
        self.binary_metrics = binary_metrics
        # Last results saved by `evaluate`; also written to by
        # `pseudo_multiclass_evaluation`.
        self.results = {}

    @classmethod
    def from_config(cls, cfg):
        """Build an evaluator from a config mapping.

        Args:
            cfg: Mapping with required keys ``"metrics"`` and
                ``"prediction_metrics"`` and optional key ``"binary_metrics"``.
        """
        return cls(
            cfg["metrics"],
            cfg["prediction_metrics"],
            binary_metrics=cfg.get("binary_metrics", None),
        )

    def evaluate(self, y_true, y_score, metric_kwarg_dict=None, save_results=True):
        """Compute every metric in ``self.metrics`` as ``fn(y_true, y_score)``.

        Args:
            y_true: Targets, forwarded unchanged to each metric.
            y_score: Scores, forwarded unchanged to each metric.
            metric_kwarg_dict: Accepted for backward compatibility but
                currently ignored; no per-metric keyword arguments are
                forwarded.
            save_results: When true, replace ``self.results`` with the new
                results.

        Returns:
            dict: Metric name -> metric value.

        Raises:
            KeyError: If a configured metric name is not in ``METRIC_DICT``.
        """
        results = {
            metric: METRIC_DICT[metric](y_true, y_score) for metric in self.metrics
        }
        if save_results:
            self.results = results
        return results

    def evaluate_predictions(self, y_true, y_score, threshold=0.5):
        """Compute every metric in ``self.prediction_metrics``.

        Metrics from ``PRED_METRIC_DICT`` receive hard 0/1 predictions
        (``y_score > threshold``); metrics from ``PROB_METRIC_DICT`` receive
        ``y_score`` unchanged.

        Args:
            y_true: Targets, forwarded unchanged to each metric.
            y_score: Scores; any array-like (list, ndarray, Series, ...).
            threshold: Scores strictly greater than this become prediction 1.

        Returns:
            dict: Metric name -> metric value.

        Raises:
            ValueError: If a configured name is in neither registry.
        """
        results = {}
        # np.asarray lets plain Python sequences be thresholded too; a bare
        # `list > float` comparison would raise.
        y_pred = (np.asarray(y_score) > threshold).astype(int)
        for metric in self.prediction_metrics:
            if metric in PRED_METRIC_DICT:
                results[metric] = PRED_METRIC_DICT[metric](y_true, y_pred)
            elif metric in PROB_METRIC_DICT:
                results[metric] = PROB_METRIC_DICT[metric](y_true, y_score)
            else:
                raise ValueError(f"`{metric}` is not a valid metric.")
        return results

    def pseudo_multiclass_evaluation(self, y_true, y_score, n_components=2):
        """Deprecated. This is not used in the final experiments.

        Clusters ``y_true`` with a Gaussian mixture to obtain pseudo-labels,
        then scores ``y_score`` against those labels with every metric in
        ``self.binary_metrics``.

        Args:
            y_true: One-dimensional array-like of values to cluster.
            y_score: Scores passed to each metric alongside the pseudo-labels.
            n_components: Number of mixture components (pseudo-classes).

        Returns:
            dict: ``self.results``, updated in place with the new entries.

        Raises:
            ValueError: If no binary metrics were configured.
            KeyError: If a configured name is not in ``PSEUDO_BINARY_METRICS``.
        """
        if self.binary_metrics is None:
            raise ValueError("No binary metrics were passed in.")

        y_true = np.asarray(y_true)

        # Step 1: fit a k-component GMM to estimate pseudo-labels; this should
        # work if the data is strongly multi-modal.
        init_weights = np.ones(n_components) / n_components
        # Spread the initial means evenly over the quantiles of the data.
        cluster_inits = np.quantile(
            y_true, np.linspace(0, 1, n_components)
        ).reshape(-1, 1)
        # The data range is used as a rough initial variance.
        # NOTE: a constant y_true makes this zero and the precision init
        # below infinite.
        var_init = y_true.max() - y_true.min()
        gmm = GaussianMixture(
            n_components=n_components,
            covariance_type='spherical',
            weights_init=init_weights,
            means_init=cluster_inits,
            precisions_init=np.ones(n_components) / var_init,
            random_state=42,
        )
        pseudo_y = gmm.fit_predict(y_true.reshape(-1, 1))
        for comp in range(n_components):
            print(f"Cluster {comp}:", y_true[pseudo_y == comp])

        # Step 2: run standard evaluation w.r.t. the pseudo-labels.
        for metric in self.binary_metrics:
            self.results[metric] = PSEUDO_BINARY_METRICS[metric](pseudo_y, y_score)
        return self.results

        
        
