import numpy as np
import torch
from sklearn.metrics import roc_auc_score
import joblib

class CustomEvalMetrics:
    """Binary-classification evaluation metrics computed from thresholded scores.

    A sample is predicted positive when its score is strictly greater than the
    threshold and negative otherwise, so ``tp``/``fp``/``fn``/``tn`` partition
    the samples for any threshold.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def fp(probas: np.ndarray, labels: np.ndarray, threshold: float):
        """Count false positives: label-0 samples scored above ``threshold``."""
        return np.sum((probas > threshold) & (labels == 0))

    @staticmethod
    def tp(probas: np.ndarray, labels: np.ndarray, threshold: float):
        """Count true positives: label-1 samples scored above ``threshold``."""
        return np.sum((probas > threshold) & (labels == 1))

    @staticmethod
    def fn(probas: np.ndarray, labels: np.ndarray, threshold: float):
        """Count false negatives: label-1 samples scored at or below ``threshold``."""
        # ``<=`` (not ``<``) so a score exactly at the threshold counts as a
        # negative prediction, complementing the ``>`` used by ``tp``/``fp``.
        return np.sum((probas <= threshold) & (labels == 1))

    @staticmethod
    def tn(probas: np.ndarray, labels: np.ndarray, threshold: float):
        """Count true negatives: label-0 samples scored at or below ``threshold``."""
        # ``<=`` for the same reason as in ``fn``.
        return np.sum((probas <= threshold) & (labels == 0))

    def rejection_rate(self, probas: np.ndarray, labels: np.ndarray, n_split=1,
                       thresholds=None):
        """Scan thresholds and compute TPR and rejection (N / FP) per split.

        Args:
            probas: predicted scores, one per sample.
            labels: ground-truth labels (1 = positive, 0 = negative).
            n_split: number of contiguous chunks (via ``np.array_split``)
                evaluated independently.
            thresholds: thresholds to scan. Defaults to 1000 evenly spaced
                values from 0.01 to 1 inclusive.

        Returns:
            ``(all_tprs, all_rejs)``: two lists with one inner list per split.
            Each inner list holds one value per threshold. TPR is TP / P and
            rejection is N / FP. A split with no positives yields ``nan`` TPRs.
            A threshold with zero false positives yields ``inf`` rejection
            (or ``nan`` if the split also has no negatives).
        """
        if thresholds is None:
            thresholds = np.linspace(0.01, 1, 1000)
        all_rejs = []
        all_tprs = []

        parted_probs = np.array_split(probas, n_split)
        parted_labels = np.array_split(labels, n_split)
        for part_probs, part_labels in zip(parted_probs, parted_labels):
            rejs = []
            tprs = []
            n_pos = np.sum(part_labels == 1)
            n_neg = np.sum(part_labels == 0)
            # Zero denominators are expected (FP == 0 at high thresholds, or
            # an empty class). Keep numpy's inf/nan results but do not emit a
            # RuntimeWarning for every scanned threshold.
            with np.errstate(divide="ignore", invalid="ignore"):
                for th in thresholds:
                    tprs.append(self.tp(part_probs, part_labels, th) / n_pos)
                    rejs.append(n_neg / self.fp(part_probs, part_labels, th))
            all_rejs.append(rejs)
            all_tprs.append(tprs)
        return all_tprs, all_rejs

    def rej_at_eff(self, tprs, rejs, efficiency: float):
        """Return the rejection at the scanned point whose TPR is closest to ``efficiency``.

        ``tprs`` and ``rejs`` are parallel sequences, e.g. one split's curves
        from ``rejection_rate``. Both are flattened identically, so nested
        input (such as the full single-split ``[[...]]`` output of
        ``rejection_rate``) indexes the same element in each. On ties the
        first occurrence wins (``np.argmin`` semantics).
        """
        tprs_flat = np.asarray(tprs, dtype=float).ravel()
        rejs_flat = np.asarray(rejs).ravel()
        best = np.argmin(np.abs(tprs_flat - efficiency))
        return rejs_flat[best]

    def __call__(self,
                 probas: np.ndarray,
                 labels: np.ndarray,
                 threshold: float,
                 *,
                 write_to: str = "",
                 save_data: bool = False
                 ):
        """Compute summary metrics for binary scores.

        Args:
            probas: predicted scores, one per sample.
            labels: ground-truth labels (1 = positive, 0 = negative).
            threshold: decision threshold for accuracy. A sample is predicted
                positive when its score is strictly greater than it.
            write_to: if non-empty, the metrics dict is written to this path
                with ``joblib.dump``.
            save_data: currently unused; kept for interface compatibility.

        Returns:
            dict with ``"accuracy"``, ``"rejection_rate_30"`` /
            ``"rejection_rate_50"`` (rejection at the scanned point whose TPR
            is closest to 0.3 / 0.5) and ``"auc_score"`` (sklearn ROC AUC).
        """
        acc = np.sum((probas > threshold) == labels) / len(labels)
        # ``rejection_rate`` returns one curve per split. With the default
        # single split, use that first (and only) curve; passing the nested
        # lists directly used to index out of range.
        all_tprs, all_rejs = self.rejection_rate(probas, labels)
        tprs, rejs = all_tprs[0], all_rejs[0]
        rej_rate_30 = self.rej_at_eff(tprs, rejs, 0.3)
        rej_rate_50 = self.rej_at_eff(tprs, rejs, 0.5)
        auc_score = roc_auc_score(labels, probas)

        metrics_dict = {
            "accuracy": acc,
            "rejection_rate_30": rej_rate_30,
            "rejection_rate_50": rej_rate_50,
            "auc_score": auc_score
        }

        if write_to:
            joblib.dump(metrics_dict, write_to)
        return metrics_dict