from tqdm import tqdm
import os
import argparse
import pickle
import warnings

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from KDEpy import FFTKDE
from sklearn.metrics.pairwise import pairwise_distances
import scipy
import sklearn


from nnlib.nnlib import utils


def compute_acc(preds, mask, dataset, target_label=None):
    """Compute top-1 accuracy on the rows of ``preds`` selected by ``mask``.

    For every position ``i`` of ``mask`` the row ``2 * i + mask[i]`` of ``preds``
    (and of the label vector) is used.

    :param preds: Tensor of per-class scores; the predicted class is ``argmax(dim=1)``.
    :param mask: Integer array added to ``2 * arange(len(mask))`` to build the row
        indices (presumably 0/1 entries -- confirm against the caller).
    :param dataset: Iterable of ``(x, y)`` pairs; only ``y`` is used, in iteration order.
    :param target_label: If given, labels are binarized one-vs-rest
        (1 where the label equals ``target_label``, 0 elsewhere).

    :return: The accuracy, converted with ``utils.to_numpy``.
    """
    labels = torch.tensor([y for _, y in dataset]).long()
    if target_label is not None:
        # Binarize in one step. Two sequential in-place assignments
        # (non-target -> 0, then target -> 1) map every label to 1 when
        # target_label == 0.
        labels = (labels == target_label).long()
    indices = 2 * np.arange(len(mask)) + mask
    acc = (preds[indices].argmax(dim=1) == labels[indices]).float().mean()
    return utils.to_numpy(acc)

def compute_acc_l2(preds, mask, dataset, target_label=None):
    """Compute the mean squared difference between labels and a confidence score.

    The confidence of a sample is the row-wise maximum of the selected
    predictions; for samples whose label is 0 it is replaced by ``1 - confidence``.

    :param preds: Tensor of per-class scores. It is NOT modified by this function.
    :param mask: Integer array added to ``2 * arange(len(mask))`` to build the row
        indices (presumably 0/1 entries -- confirm against the caller).
    :param dataset: Iterable of ``(x, y)`` pairs; only ``y`` is used, in iteration order.
    :param target_label: If given, labels are binarized one-vs-rest
        (1 where the label equals ``target_label``, 0 elsewhere).

    :return: The mean squared error, converted with ``utils.to_numpy``.
    """
    labels = torch.tensor([y for _, y in dataset])
    if target_label is not None:
        # Binarize in one step; the sequential in-place version maps every
        # label to 1 when target_label == 0.
        labels = (labels == target_label).to(labels.dtype)
    indices = 2 * np.arange(len(mask)) + mask

    # Indexing with an integer array returns a copy, so the caller's ``preds``
    # is left untouched (writing the softmax back into ``preds`` made repeated
    # calls apply softmax repeatedly).
    selected = preds[indices]
    selected_labels = labels[indices]

    # NOTE(review): softmax rows sum to one up to rounding, so this test holds
    # for almost any input and is not a reliable "already a probability" check.
    # It is kept unchanged so that returned values stay the same.
    if selected.softmax(1).sum(1).sum() == len(selected):
        selected = selected.softmax(1)

    conf, _ = selected.max(dim=1)
    negatives = selected_labels == 0
    conf[negatives] = 1 - conf[negatives]  # reverse the probability for label y=0
    acc = ((selected_labels - conf) ** 2).mean()
    return utils.to_numpy(acc)

def interpolate_nan(a):
    """Fill NaN entries of a 1-D array by linear interpolation.

    NaNs outside the range of the valid entries take the value of the nearest
    valid entry (the clamping behaviour of ``np.interp``). The input array is
    not modified; the result is returned as a float32 tensor.
    Slightly modified from the code in the "minimum-calibration..." NeurIPS2023.
    """
    filled = a.copy()
    missing = np.isnan(filled)
    present = ~missing
    positions = np.arange(len(filled))
    filled[missing] = np.interp(positions[missing], positions[present], filled[present])
    return torch.tensor(filled).float()

def idx_bins(confidence, n_bins):
    """Map each confidence value to the zero-based index of its bin.

    ``n_bins`` holds the bin boundaries. ``np.digitize`` results are capped at
    ``len(n_bins) - 1`` before shifting to zero-based indices, so a value equal
    to the last boundary lands in the last bin.
    """
    top = len(n_bins) - 1
    one_based = np.digitize(confidence.numpy(), n_bins)
    zero_based = np.minimum(one_based, top) - 1
    return torch.tensor(zero_based)


#def calc_ECE(confidences, labels, num_bins=15, norm='l1', method='uniform', recalibrate=True, strategy='label'):
def calc_ECE(confidences, labels, bins, norm='l1', method='uniform', recalibrate=False, strategy='label'):
    """
    Calculate the binned ECE, optionally after histogram recalibration on the same data.

    The confidence of a sample is the row-wise maximum of ``confidences``. Samples
    are assigned to bins with ``idx_bins``; per bin, the mean confidence is compared
    with the mean of ``labels`` and the gaps are averaged with the bin proportions
    as weights.

    :param confidences: 2-D tensor of per-class scores. If the rows do not sum to one
        (up to the rounding error of the dtype) softmax is applied first.
    :param labels: 1-D tensor with one value per row of ``confidences``.
    :param bins: Bin boundaries (e.g. the output of ``compute_bins``).
    :param norm: ``'l1'`` or ``'l2'``.
    :param method: Not used by this function.
    :param recalibrate: If True, every confidence is replaced by the mean of its bin
        (empty bins are filled by ``interpolate_nan``) before the ECE is computed.
    :param strategy: Quantity averaged per bin when recalibrating: ``'label'`` or
        ``'probability'``.

    :return: The ECE as a 0-d tensor.
    :raises ValueError: If the scores are not in [0, 1] after normalization, or if
        ``strategy`` / ``norm`` is unknown.
    """
    # A fixed 1e-10 tolerance is below float32 rounding error (~1e-7), so valid
    # float32 probabilities could be softmaxed a second time. Scale it by dtype.
    if confidences.is_floating_point():
        tol = max(1e-10, 100 * torch.finfo(confidences.dtype).eps)
    else:
        tol = 1e-10
    if not torch.all(torch.abs(torch.sum(confidences, dim=1) - 1) < tol):
        print("make softmax prob.")
        confidences = confidences.softmax(1)

    if not torch.all((confidences >= 0) & (confidences <= 1)):
        raise ValueError("This is not softmax prob.")

    # Top-label confidence. The reversal "1 - confidence" for label y=0 is
    # intentionally not applied in this function.
    confidences, _ = confidences.max(dim=1)

    n_bins = bins

    with torch.no_grad():
        # One slot per boundary; the last slot is never hit by idx_bins and stays 0.
        conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)
        count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)
        label_bin = torch.zeros(len(n_bins), device=labels.device, dtype=labels.dtype)
        idx = idx_bins(confidences, n_bins)

        if recalibrate:
            # Number of samples per bin.
            bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device)
            if strategy == 'label':
                # Per-bin sum of labels.
                bin_true = (torch.bincount(idx, weights=labels, minlength=len(n_bins)-1)).float().to(confidences.device)
            elif strategy == 'probability':
                # Per-bin sum of confidences.
                bin_true = (torch.bincount(idx, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device)
            else:
                raise ValueError(f"Unexpected strategy: {strategy}.")

            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
                # Empty bins give 0/0 = nan; fill them by interpolation assuming
                # smoothness. This is \hat{\mu} in Eq.(9) of Sun et al. (2023).
                bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy())
            # interpolate_nan always returns float32; scatter_add_ below requires
            # the source to match the accumulator dtype (e.g. float64 inputs).
            confidences = bin_mean[idx].to(dtype=conf_bin.dtype, device=conf_bin.device)

        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))
        conf_bin.scatter_add_(dim=0, index=idx, src=confidences)
        # Empty bins yield 0/0; nan_to_num maps them to 0 (their weight is 0 anyway).
        conf_bin = torch.nan_to_num(conf_bin / count_bin)
        prop_bin = count_bin / count_bin.sum()

        label_bin.scatter_add_(dim=0, index=idx, src=labels)
        label_bin = torch.nan_to_num(label_bin / count_bin)

    if norm == 'l1':
        ece = torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)
    elif norm == 'l2':
        ece = torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))
    else:
        raise ValueError(f"Unexpected norm type: {norm}")

    return ece

#def compute_ece(preds, mask, dataset, num_bins=10, norm='l1', method='uniform', recalibrate=False, strategy='label'):
def compute_ece(preds, mask, dataset, bins, norm='l1', method='uniform', recalibrate=False, strategy='label', target_label=None):
    """Compute the binned ECE (via ``calc_ECE``) on the rows of ``preds`` selected by ``mask``.

    :param preds: Tensor of per-class scores.
    :param mask: Integer array added to ``2 * arange(len(mask))`` to build the row
        indices (presumably 0/1 entries -- confirm against the caller).
    :param dataset: Iterable of ``(x, y)`` pairs; only ``y`` is used, in iteration order.
    :param bins: Bin boundaries forwarded to ``calc_ECE``.
    :param norm: Forwarded to ``calc_ECE``.
    :param method: Forwarded to ``calc_ECE``.
    :param recalibrate: Forwarded to ``calc_ECE``.
    :param strategy: Forwarded to ``calc_ECE``.
    :param target_label: If given, labels are binarized one-vs-rest
        (1 where the label equals ``target_label``, 0 elsewhere).

    :return: The ECE, converted with ``utils.to_numpy``.
    """
    labels = torch.tensor([y for _, y in dataset]).long()
    if target_label is not None:
        # Binarize in one step; the sequential in-place version maps every
        # label to 1 when target_label == 0.
        labels = (labels == target_label).long()
    indices = 2 * np.arange(len(mask)) + mask

    ece = calc_ECE(preds[indices], labels[indices], bins, norm=norm, method=method, recalibrate=recalibrate, strategy=strategy)

    return utils.to_numpy(ece)

def compute_bins(num_bins, confidences=None, method='uniform'):
    """Build ``num_bins + 1`` bin boundaries on [0, 1].

    :param num_bins: Number of bins.
    :param confidences: Confidence values; required for ``method='quantile'``.
    :param method: ``'uniform'`` for equally spaced boundaries, ``'quantile'`` for
        boundaries at the empirical quantiles of ``confidences``.

    :return: Tensor of boundaries whose first and last entries are forced to 0 and 1.
    :raises ValueError: If ``method`` is unknown, or ``confidences`` is missing for
        the quantile method.
    """
    levels = torch.linspace(0, 1, num_bins + 1)
    if method == 'uniform':
        n_bins = levels
    elif method == 'quantile':
        # Identity test: ``confidences == None`` is evaluated element-wise for
        # NumPy arrays and then fails inside ``if`` with an ambiguous truth value.
        if confidences is None:
            raise ValueError("confidence values are needed.")
        n_bins = torch.tensor(np.quantile(confidences, levels.numpy()))
    else:
        raise ValueError(f"Unexpected binning method: {method}")

    # Pin the outer boundaries so that every value in [0, 1] falls in some bin.
    n_bins[0], n_bins[-1] = 0., 1.
    return n_bins

def get_bandwidth(f, device):
    """
    Select a bandwidth for the kernel based on maximizing the leave-one-out likelihood (LOO MLE).

    Candidates are 15 log-spaced values in [1e-5, 1e-1] followed by 5 linearly
    spaced values in [0.2, 1].

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param device: The device type: 'cpu' or 'cuda'

    :return: The bandwidth of the kernel (a 0-d tensor). ``-1`` is returned only if no
        candidate yields a likelihood greater than ``-inf`` (e.g. all are NaN).
    """
    bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=15), torch.linspace(0.2, 1, steps=5)))
    max_b = -1
    # Log-likelihoods can be negative, so the running maximum must start at -inf;
    # starting at 0 rejected every candidate whenever all likelihoods were negative.
    max_l = -float('inf')
    n = len(f)
    for b in bandwidths:
        log_kern = get_kernel(f, b, device)
        # The diagonal of log_kern is masked out, so each density estimate
        # averages over the n - 1 other samples.
        log_fhat = torch.logsumexp(log_kern, 1) - np.log(n-1)
        l = torch.sum(log_fhat)
        if l > max_l:
            max_l = l
            max_b = b

    return max_b


def get_ece_kde(f, y, bandwidth, p, mc_type, device, kernel='dirichlet'):
    """
    Calculate an estimate of Lp calibration error.

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param y: The vector containing the labels, shape [num_samples]
    :param bandwidth: The bandwidth of the kernel
    :param p: The p-norm. Typically, p=1 or p=2
    :param mc_type: The type of multiclass calibration: canonical, marginal or top_label.
        Ignored when ``f`` has a single column.
    :param device: The device type: 'cpu' or 'cuda'
    :param kernel: Kernel name forwarded to the estimators ('dirichlet' or 'gaussian').
        The marginal estimator does not take it.

    :return: An estimate of Lp calibration error, converted with ``utils.to_numpy``
    :raises ValueError: If ``f`` has more than one column and ``mc_type`` is unknown
    """
    check_input(f, bandwidth, mc_type)
    if f.shape[1] == 1:
        return utils.to_numpy(2 * get_ratio_binary(f, y, bandwidth, p, device, kernel))

    if mc_type == 'canonical':
        # Uses the Dirichlet kernel (K >= 2) or the Gaussian kernel.
        ratio = get_ratio_canonical(f, y, bandwidth, p, device, kernel)
    elif mc_type == 'marginal':
        # Uses the beta kernel.
        ratio = get_ratio_marginal_vect(f, y, bandwidth, p, device)
    elif mc_type == 'top_label':
        ratio = get_ratio_toplabel(f, y, bandwidth, p, device, kernel)
    else:
        # Falling through used to return None silently.
        raise ValueError(f"Unexpected mc_type: {mc_type}.")
    return utils.to_numpy(ratio)


def get_ratio_binary(f, y, bandwidth, p, device, kernel='dirichlet'):
    """Return the KDE calibration-error estimate for single-column scores.

    ``f`` must have exactly one column. The log-kernel matrix is built with
    ``get_kernel`` and passed to ``get_kde_for_ece``.
    """
    num_columns = f.shape[1]
    assert num_columns == 1

    return get_kde_for_ece(f, y, get_kernel(f, bandwidth, device, kernel), p)

def get_ratio_canonical(f, y, bandwidth, p, device, kernel='dirichlet'):
    """Return the canonical (full-vector) KDE calibration-error estimate.

    For every sample the kernel-weighted average of the one-hot labels of the
    other samples is compared with the predicted score vector; the p-th powers of
    the absolute differences are summed over classes and averaged over samples.

    :param f: Probability scores, shape [num_samples, num_classes]
    :param y: Integer labels, shape [num_samples]
    :param bandwidth: The bandwidth of the kernel
    :param p: The p-norm
    :param device: The device type: 'cpu' or 'cuda'
    :param kernel: Kernel name forwarded to ``get_kernel``

    :return: The estimate as a 0-d tensor
    """
    if f.shape[1] > 60:
        # Slower but more numerically stable implementation for larger number of classes
        return get_ratio_canonical_log(f, y, bandwidth, p, device, kernel)

    log_kern = get_kernel(f, bandwidth, device, kernel)
    kern = torch.exp(log_kern).to(torch.float32)

    # Put the remaining operands on the kernel's device, as
    # get_ratio_marginal_vect does; otherwise inputs living on another device
    # fail in the matmul / subtraction below.
    scores = f.to(kern.device)
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32).to(kern.device)
    kern_y = torch.matmul(kern, y_onehot)
    den = torch.sum(kern, dim=1)
    # to avoid division by 0
    den = torch.clamp(den, min=1e-10)

    ratio = kern_y / den.unsqueeze(-1)
    ratio = torch.sum(torch.abs(ratio - scores)**p, dim=1)

    return torch.mean(ratio)

# Note for training: Make sure there are at least two examples for every class present in the batch, otherwise
# LogsumexpBackward returns nans.
def get_ratio_canonical_log(f, y, bandwidth, p, device, kernel='dirichlet'):
    """Log-space variant of the canonical KDE calibration-error estimate.

    Computes, class by class, the kernel-weighted fraction of the other samples
    carrying that class using ``logsumexp``, and accumulates the p-th power of its
    absolute difference to the predicted score.

    :param f: Probability scores, shape [num_samples, num_classes]
    :param y: Integer labels, shape [num_samples]
    :param bandwidth: The bandwidth of the kernel
    :param p: The p-norm
    :param device: The device type: 'cpu' or 'cuda'
    :param kernel: Kernel name forwarded to ``get_kernel``

    :return: The estimate as a 0-d tensor
    """
    log_kern = get_kernel(f, bandwidth, device, kernel)
    scores = f.to(log_kern.device)
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)
    # log(0) = -inf removes samples of other classes from the numerator.
    log_y = torch.log(y_onehot).to(log_kern.device)
    log_den = torch.logsumexp(log_kern, dim=1)
    final_ratio = 0
    for k in range(f.shape[1]):
        # Broadcasting the [1, N] row over the [N, N] kernel gives the same values
        # as multiplying by a ones([N, 1]) column first, without creating a
        # CPU-only helper tensor.
        log_kern_y = log_kern + log_y[:, k].unsqueeze(0)
        log_inner_ratio = torch.logsumexp(log_kern_y, dim=1) - log_den
        inner_ratio = torch.exp(log_inner_ratio)
        inner_diff = torch.abs(inner_ratio - scores[:, k])**p
        final_ratio += inner_diff

    return torch.mean(final_ratio)

def get_ratio_marginal_vect(f, y, bandwidth, p, device):
    """Return the marginal (class-wise) KDE calibration-error estimate.

    A beta log-kernel is evaluated per class, the diagonal (each sample paired
    with itself) is pushed to a very large negative value for every class, and
    the result is handed to ``get_kde_for_ece_vect``.
    """
    num_samples = len(f)
    num_classes = f.shape[1]

    onehot = nn.functional.one_hot(y, num_classes=num_classes).to(torch.float32).to(device)
    log_kern = beta_kernel(f, f, bandwidth).squeeze().to(device)

    # Same diagonal mask for every class, stacked along the last dimension.
    diag_mask = torch.diag(torch.finfo(torch.float).min * torch.ones(num_samples)).to(device)
    stacked_mask = torch.stack([diag_mask] * num_classes, dim=2)
    masked_log_kern = log_kern + stacked_mask

    return get_kde_for_ece_vect(f.to(device), onehot, masked_log_kern, p)


def get_ratio_toplabel(f, y, bandwidth, p, device, kernel):
    """Return the top-label KDE calibration-error estimate.

    The problem is reduced to the single-column case: the score is the largest
    entry of each row and the target is 1 where the arg-max equals the label.
    """
    top = torch.max(f, 1)
    top_score = top.values.unsqueeze(-1)
    is_correct = (y == top.indices).to(torch.int)

    return get_ratio_binary(top_score, is_correct, bandwidth, p, device, kernel)


def get_kde_for_ece_vect(f, y, log_kern, p):
    """Class-wise KDE calibration error from a per-class log-kernel tensor.

    :param f: Scores; subtracted from the per-sample, per-class ratio, so it must
        broadcast against the result of reducing ``log_kern`` over dim 1.
    :param y: One-hot labels (non-zero marks membership); broadcast against
        ``log_kern`` to select the numerator terms.
    :param log_kern: Log-kernel values; the sum over dim 1 runs over the other samples.
    :param p: The p-norm

    :return: Sum over classes of the mean (over samples) of ``|ratio - f| ** p``
    """
    # Select the numerator terms with a mask instead of multiplying by y:
    # multiplication gives nan for -inf * 0, and a following "== 0" test would also
    # throw away genuine kernel values whose logarithm is exactly 0.
    masked_out = torch.full_like(log_kern, torch.finfo(torch.float).min)
    log_kern_y = torch.where(y != 0, log_kern, masked_out)

    log_num = torch.logsumexp(log_kern_y, dim=1)
    log_den = torch.logsumexp(log_kern, dim=1)

    log_ratio = log_num - log_den
    ratio = torch.exp(log_ratio)
    ratio = torch.abs(ratio - f)**p

    return torch.sum(torch.mean(ratio, dim=0))


def get_kde_for_ece(f, y, log_kern, p):
    """KDE calibration error for single-column scores.

    For every sample, the kernel mass of the samples with ``y == 1`` is divided
    by the total kernel mass (both in log space) and compared with the score.
    Two special cases are handled separately: no positive sample at all, and
    exactly one positive sample (whose own row is dropped from the kernel and
    accounted for with the term ``f_one ** p``).
    """
    f = f.squeeze()
    N = len(f)

    # Positions of the samples with y = 1.
    positives = torch.where(y == 1)[0]
    num_positives = positives.numel()

    if num_positives == 0:
        return torch.sum((torch.abs(-f))**p) / N

    lone_positive = num_positives == 1
    if lone_positive:
        # The row of the only positive sample would contain nothing but the
        # masked diagonal entry in the selected column, so remove that row.
        log_kern = torch.cat((log_kern[:positives], log_kern[positives+1:]))
        f_one = f[positives]
        f = torch.cat((f[:positives], f[positives+1:]))

    log_num = torch.logsumexp(torch.index_select(log_kern, 1, positives), dim=1)
    log_den = torch.logsumexp(log_kern, dim=1)
    gap = torch.abs(torch.exp(log_num - log_den) - f)**p

    if lone_positive:
        return (gap.sum() + f_one ** p)/N
    return torch.mean(gap)


def get_kernel(f, bandwidth, device, kernel='dirichlet'):
    """Build the log-kernel matrix between all pairs of samples.

    :param f: Probability scores, shape [num_samples, num_classes]
    :param bandwidth: The bandwidth of the kernel
    :param device: Device the returned matrix is placed on
    :param kernel: ``'dirichlet'`` (the beta kernel is used when ``f`` has a single
        column) or ``'gaussian'``

    :return: Log-kernel matrix whose diagonal is pushed to a very large negative
        value so that a sample does not contribute to its own estimate
    :raises ValueError: If ``kernel`` is unknown
    """
    if kernel == 'dirichlet':
        # if num_classes == 1
        if f.shape[1] == 1:
            log_kern = beta_kernel(f, f, bandwidth).squeeze().to(device)
        else:
            log_kern = dirichlet_kernel(f, bandwidth).squeeze().to(device)
    elif kernel == 'gaussian':
        # Moved to ``device`` like the other kernels; otherwise adding the
        # diagonal mask below fails whenever ``device`` is not the CPU.
        log_kern = gaussian_kernel(f, bandwidth).squeeze().to(device)
    else:
        raise ValueError(f"Unexpected kernel: {kernel}.")

    # Trick: -inf on the diagonal
    return log_kern + torch.diag(torch.finfo(torch.float).min * torch.ones(len(f))).to(device)


def beta_kernel(z, zi, bandwidth=0.1):
    """Return the log-density of a beta kernel.

    The kernel centred at ``zi`` has shape parameters ``zi / bandwidth + 1`` and
    ``(1 - zi) / bandwidth + 1``. ``z`` gets an extra axis at position -2, so
    the evaluation points broadcast against the centres.
    """
    shape_a = zi / bandwidth + 1
    shape_b = (1-zi) / bandwidth + 1
    points = z.unsqueeze(-2)

    log_normalizer = torch.lgamma(shape_a) + torch.lgamma(shape_b) - torch.lgamma(shape_a + shape_b)
    log_unnormalized = (shape_a-1) * torch.log(points) + (shape_b-1) * torch.log(1-points)

    return log_unnormalized - log_normalizer


def dirichlet_kernel(z, bandwidth=0.1):
    """Return the matrix of Dirichlet log-densities between all rows of ``z``.

    Each row of ``z`` defines a kernel with concentration ``z / bandwidth + 1``;
    entry ``[i, j]`` is the log-density of row ``i`` under the kernel of row ``j``.
    """
    concentration = z / bandwidth + 1

    log_normalizer = torch.lgamma(concentration).sum(dim=1) - torch.lgamma(concentration.sum(dim=1))
    log_unnormalized = torch.matmul(torch.log(z), (concentration-1).T)

    return log_unnormalized - log_normalizer

def gaussian_kernel(z, bandwidth=0.1, median=True):
    """Return the unnormalized Gaussian log-kernel matrix between all rows of ``z``.

    The entry for a pair is ``-squared_distance / 2 / bandwidth``. When
    ``median`` is True the ``bandwidth`` argument is ignored and replaced by the
    median of all pairwise squared distances. The result is a float32 tensor.
    """
    sq_dists = pairwise_distances(z.numpy(), metric='sqeuclidean')
    if median:
        bandwidth = np.median(sq_dists)

    return torch.tensor(-sq_dists / 2 / bandwidth).float()


def check_input(f, bandwidth, mc_type):
    """Assert that the inputs of the KDE estimator are well formed.

    ``f`` must be NaN-free, two-dimensional and lie in [0, 1]; ``bandwidth``
    must be positive. ``mc_type`` is accepted but not checked here.
    """
    has_nan = isnan(f)
    assert not has_nan
    assert len(f.shape) == 2
    assert bandwidth > 0
    lowest, highest = torch.min(f), torch.max(f)
    assert lowest >= 0
    assert highest <= 1


def isnan(a):
    """Return a boolean 0-d tensor that is True if any entry of ``a`` is NaN."""
    nan_mask = torch.isnan(a)
    return torch.any(nan_mask)


def compute_ece2(preds, mask, dataset, ece_loss, target_label=None):
    """Evaluate ``ece_loss`` on the rows of ``preds`` selected by ``mask``.

    :param preds: Tensor of per-class scores.
    :param mask: Integer array added to ``2 * arange(len(mask))`` to build the row
        indices (presumably 0/1 entries -- confirm against the caller).
    :param dataset: Iterable of ``(x, y)`` pairs; only ``y`` is used, in iteration order.
    :param ece_loss: Object whose ``forward(preds, labels)`` returns the loss
        (e.g. an ``_ECELoss`` instance).
    :param target_label: If given, labels are binarized one-vs-rest
        (1 where the label equals ``target_label``, 0 elsewhere).

    :return: The loss value, converted with ``utils.to_numpy``.
    """
    labels = torch.tensor([y for _, y in dataset]).long()
    if target_label is not None:
        # Binarize in one step; the sequential in-place version maps every
        # label to 1 when target_label == 0.
        labels = (labels == target_label).long()
    indices = 2 * np.arange(len(mask)) + mask

    ece = ece_loss.forward(preds[indices], labels[indices])

    return utils.to_numpy(ece)

class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model (Slightly modified by M.F.).
    (This isn't necessary for temperature scaling, just a cool metric).

    The input may be logits or probabilities: rows that do not sum to one (up to
    the rounding error of the dtype) are passed through softmax first.

    The confidence of a sample is its largest class probability, replaced by
    ``1 - confidence`` when the predicted class is 0. Confidences are divided into
    interval bins (equally spaced, or at the quantiles of a reference set).
    In each bin, we compute the gap:

    bin_gap = | avg_confidence_in_bin - avg_label_in_bin |

    (the labels themselves are averaged, not the correctness of the prediction).
    We then return a weighted average of the gaps, based on the number
    of samples in each bin

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15, method='uniform', logits=None):
        """
        n_bins (int): number of confidence interval bins
        method (str): 'uniform' for equally spaced bins, 'quantile' for bins at the
            quantiles of the confidences derived from ``logits``
        logits (Tensor): reference scores, required for method='quantile'; their
            row-wise maximum is used as is (no softmax is applied here)

        Raises ValueError for an unknown method or missing ``logits``.
        """
        super(_ECELoss, self).__init__()
        if method == 'uniform':
            bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        elif method == 'quantile':
            if logits is None:
                raise ValueError("logits are needed for quantile binning.")
            conf, pred = logits.max(1)
            conf[pred==0] = 1 - conf[pred==0]
            bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))
        else:
            # Without this branch an unknown method failed later with an
            # UnboundLocalError on bin_boundaries.
            raise ValueError(f"Unexpected binning method: {method}")

        # Pin the outer boundaries to the ends of the unit interval.
        bin_boundaries[0], bin_boundaries[-1] = 0., 1.
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        """Return the ECE of ``logits`` with respect to ``labels`` as a tensor of shape [1]."""
        # A fixed 1e-10 tolerance is below float32 rounding error (~1e-7), so valid
        # float32 probabilities could be softmaxed a second time. Scale it by dtype.
        if logits.is_floating_point():
            tol = max(1e-10, 100 * torch.finfo(logits.dtype).eps)
        else:
            tol = 1e-10
        if not torch.all(torch.abs(torch.sum(logits, dim=1) - 1) < tol):
            print("make softmax prob.")
            softmaxes = F.softmax(logits, dim=1)
        else:
            softmaxes = logits
        confidences, predictions = torch.max(softmaxes, 1)
        # Reverse the probability where class 0 is predicted.
        confidences[predictions==0] = 1 - confidences[predictions==0]
        # The labels are compared with the confidences directly.
        accuracies = labels

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin (lower, upper]
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

############ For recalibration through training data reusing ############
def compute_lsloss(preds, mask, dataset, loss, target_label=None):
    """Evaluate ``loss`` on all predictions, using ``mask`` to split train and test rows.

    :param preds: Tensor of per-class scores, passed to ``loss`` unindexed.
    :param mask: Selector array forwarded to ``loss.forward``.
    :param dataset: Iterable of ``(x, y)`` pairs; only ``y`` is used, in iteration order.
    :param loss: Object whose ``forward(preds, labels, mask)`` returns a pair
        ``(acc_loss, bin_loss)`` (e.g. a ``_LabelLoss`` instance).
    :param target_label: If given, labels are binarized one-vs-rest
        (1 where the label equals ``target_label``, 0 elsewhere).

    :return: Tuple ``(acc_loss, bin_loss)``, each converted with ``utils.to_numpy``.
    """
    labels = torch.tensor([y for _, y in dataset]).long()
    if target_label is not None:
        # Binarize in one step; the sequential in-place version maps every
        # label to 1 when target_label == 0.
        labels = (labels == target_label).long()

    acc_loss, bin_loss = loss.forward(preds, labels, mask)

    return (utils.to_numpy(acc_loss), utils.to_numpy(bin_loss))

class _LabelLoss(nn.Module):
    """
    Compares, bin by bin, the label sums of the "train" and "test" halves of a
    paired prediction set.

    Rows ``2 * i + masks[i]`` form the train half and rows ``2 * i + (1 - masks[i])``
    the test half. Confidences (largest class probability, reversed to
    ``1 - confidence`` when class 0 is predicted) are binned separately for both
    halves using the same boundaries.
    """
    def __init__(self, n_bins=15, method='uniform', logits=None):
        """
        n_bins (int): number of confidence interval bins
        method (str): 'uniform' for equally spaced bins, 'quantile' for bins at the
            quantiles of the confidences derived from ``logits``
        logits (Tensor): reference scores, required for method='quantile'; their
            row-wise maximum is used as is (no softmax is applied here)

        Raises ValueError for an unknown method or missing ``logits``.
        """
        super(_LabelLoss, self).__init__()
        if method == 'uniform':
            bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        elif method == 'quantile':
            if logits is None:
                raise ValueError("logits are needed for quantile binning.")
            conf, pred = logits.max(1)
            conf[pred==0] = 1 - conf[pred==0]
            bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))
        else:
            # Without this branch an unknown method failed later with an
            # UnboundLocalError on bin_boundaries.
            raise ValueError(f"Unexpected binning method: {method}")

        # Pin the outer boundaries to the ends of the unit interval.
        bin_boundaries[0], bin_boundaries[-1] = 0., 1.
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels, masks):
        """
        Return ``(acc, bin_loss)``, two tensors of shape [1].

        ``acc`` accumulates, per bin, the absolute difference between the test and
        train label sums divided by the number of pairs (a bin populated by only
        one half contributes that half's label sum). ``bin_loss`` accumulates the
        proportion of test samples in every bin that contains test samples.
        """
        # A fixed 1e-10 tolerance is below float32 rounding error (~1e-7), so valid
        # float32 probabilities could be softmaxed a second time. Scale it by dtype.
        if logits.is_floating_point():
            tol = max(1e-10, 100 * torch.finfo(logits.dtype).eps)
        else:
            tol = 1e-10
        if not torch.all(torch.abs(torch.sum(logits, dim=1) - 1) < tol):
            print("make softmax prob.")
            softmaxes = F.softmax(logits, dim=1)
        else:
            softmaxes = logits

        train_idx = 2*np.arange(len(masks)) + masks
        test_idx = 2*np.arange(len(1-masks)) + (1-masks)
        num_data = len(train_idx)
        print(num_data)  # diagnostic output retained from the original implementation

        pred_tr, pred_te = softmaxes[train_idx], softmaxes[test_idx]
        ls_tr, ls_te = labels[train_idx], labels[test_idx]

        conf_tr, pl_tr = torch.max(pred_tr, 1)
        conf_tr[pl_tr==0] = 1 - conf_tr[pl_tr==0]
        conf_te, pl_te = torch.max(pred_te, 1)
        # Reverse the TEST confidences with their own values; using the train
        # confidences here mixed the two halves.
        conf_te[pl_te==0] = 1 - conf_te[pl_te==0]

        acc = torch.zeros(1, device=pred_tr.device)
        bin_loss = torch.zeros(1, device=pred_tr.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Membership of each half in the bin (lower, upper]
            in_bin = conf_tr.gt(bin_lower.item()) * conf_tr.le(bin_upper.item())
            in_bin_te = conf_te.gt(bin_lower.item()) * conf_te.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            prop_in_bin_te = in_bin_te.float().mean()
            has_train = prop_in_bin.item() > 0
            has_test = prop_in_bin_te.item() > 0
            if has_train and has_test:
                freq_tr = ls_tr[in_bin].float()
                freq_te = ls_te[in_bin_te].float()
                acc += abs((freq_te.sum() - freq_tr.sum()) / num_data)
                bin_loss += prop_in_bin_te
            elif has_train:
                acc += ls_tr[in_bin].float().sum() / num_data
            elif has_test:
                acc += ls_te[in_bin_te].float().sum() / num_data
                bin_loss += prop_in_bin_te
            # A bin that is empty in both halves contributes nothing.

        return (acc, bin_loss)