'''
Calibrate uncertainty given multiple forward passes
input
    y: shape = (N, T, C)
    y is the softmax output of the network
    N is the number of samples
    T is the number of outputs (forward passes) per sample
    C is the number of classes
output
    u: shape = (N,)
    uncertainty metric
'''
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np 



def predictive_entropy(y):
    '''
    Entropy of the categorical distribution obtained by applying softmax to y
    input
        y: logits with the class scores along dim 1, e.g. shape = (N, C)
    output
        entropy with dim 1 reduced, e.g. shape = (N,)
    '''
    # The softmax axis is pinned to the axis that is summed below; relying on
    # the implicit axis of F.softmax is deprecated and emits a warning.
    log_p = F.log_softmax(y, dim=1)
    # Working from log_softmax keeps log(p) finite for finite logits, so a
    # probability that underflows to 0 contributes 0 instead of 0 * -inf = NaN.
    return -torch.sum(log_p.exp() * log_p, dim=1)

def mean_class_variance(y):
    '''
    Understanding Measures of Uncertainty for Adversarial Example: Eq[9]
    Mean variance across classes: the biased variance is taken over dim 1 of y
    and the result is averaged over its remaining dim 1.
    '''
    per_class_var = torch.var(y, dim=1, unbiased=False)
    return torch.mean(per_class_var, dim=1)

def max_variance(y):
    '''Biased variance over dim 1 of the maximum score taken along dim 2 of y.'''
    top_scores = y.max(dim=2).values
    return torch.var(top_scores, dim=1, unbiased=False)

def max_prob(y):
    '''Largest value of y along dim 1.'''
    return y.max(dim=1).values

def reject_score(y):
    '''
    Scaled value of the last column of y
    input
        y: at least 2-D; only y[:, -1] is read
    output
        1/2.2 * y[:, -1]
    '''
    # NOTE(review): the origin of the 2.2 constant is not visible in this
    # file - confirm against the caller before changing it.
    scale = 1 / 2.2
    # The debug print of y[:, -1] that used to run on every call was removed.
    return scale * y[:, -1]

def target_variance(y, tar):
    '''Biased variance over dim 1 of the scores get_tar_score selects for tar.'''
    picked = get_tar_score(y, tar)
    return torch.var(picked, dim=1, unbiased=False)

def target_mean(y, tar):
    '''Mean over dim 1 of the scores get_tar_score selects for tar.'''
    picked = get_tar_score(y, tar)
    return torch.mean(picked, dim=1)

def mutual_information(y):
    '''
    Understanding Measures of Uncertainty for Adversarial Example: Eq[4]
    Entropy of the mean prediction minus the mean entropy of the individual
    predictions.
    input
        y: shape = (N, T, C), softmax probabilities
    output
        shape = (N,)
    '''
    mean_probs = y.mean(dim=1)  # (N, C)
    # xlogy(p, p) is p * log(p) with the 0 * log(0) case defined as 0, so
    # exactly-zero probabilities do not turn the result into NaN.
    entropy_of_mean = -torch.xlogy(mean_probs, mean_probs).sum(dim=1)
    # The second term is computed here rather than via predictive_entropy:
    # that function applies a softmax to its input and reduces dim 1, which for
    # an (N, T, C) tensor of probabilities is the pass axis, not the class axis.
    mean_of_entropy = -torch.xlogy(y, y).sum(dim=2).mean(dim=1)
    return entropy_of_mean - mean_of_entropy

def differential_entropy(alphas):
    '''
    Predictive Uncertainty Estimation via Prior Networks: Eq[18]
    input
        alphas: shape = (N, C), Dirichlet concentration parameters
    output
        shape = (N,)
    '''
    alpha0 = torch.sum(alphas, dim=1, keepdim=True)  # (N, 1)

    per_class = torch.lgamma(alphas) - (alphas - 1) * (
        torch.digamma(alphas) - torch.digamma(alpha0))
    # Drop the kept dim before subtracting: (N,) - (N, 1) would broadcast to an
    # (N, N) matrix of which only the diagonal is the wanted result.
    return torch.sum(per_class, dim=1) - torch.lgamma(alpha0).squeeze(1)

def u_sublogic(alphas):
    '''Size of dim 1 of alphas divided by the sum of alphas over dim 1.'''
    n_classes = alphas.shape[1]
    strength = alphas.sum(dim=1)
    return n_classes / strength

def kernel_distance(y):
    '''Largest value of y along dim 1.'''
    largest, _ = torch.max(y, dim=1)
    return largest

def accuracy(y, tar):
    '''
    Fraction of samples whose arg-max over dim 1 of y equals tar
    input
        y: shape = (N, C), class scores
        tar: shape = (N,), target class indices
    output
        0-dim float tensor
    '''
    pred = y.argmax(dim=1)
    # eq() yields a bool tensor, which mean() rejects; cast to float first.
    return pred.eq(tar).float().mean()

def kl_divergence(y):
    '''
    Sum of the KL divergences between every pair of passes (i < j)
    input
        y: shape = (N, T, C), probabilities
    output
        shape = (N,); all zeros when there are fewer than two passes
    '''
    n_passes = y.shape[1]
    # Start from a tensor so the result is an (N,) tensor on y's device/dtype
    # even when there is no pair to accumulate.
    total = y.new_zeros(y.shape[0])
    for i in range(n_passes):
        log_ref = y[:, i, :].log()  # same for every j, computed once
        for j in range(i + 1, n_passes):
            # F.kl_div(input, target) is pointwise target * (log(target) - input),
            # so summing over classes gives KL(y_j || y_i).
            pointwise = F.kl_div(log_ref, y[:, j, :], reduction='none')
            total = total + pointwise.sum(dim=1)
    return total
def kl_divergence_max(y):
    '''
    Largest KL divergence over every pair of passes (i < j)
    input
        y: shape = (N, T, C), probabilities
    output
        shape = (N,); all zeros when there are fewer than two passes
    '''
    n_passes = y.shape[1]
    pair_kls = []
    for i in range(n_passes):
        log_ref = y[:, i, :].log()  # same for every j, computed once
        for j in range(i + 1, n_passes):
            # F.kl_div(input, target) is pointwise target * (log(target) - input),
            # so summing over classes gives KL(y_j || y_i).
            pointwise = F.kl_div(log_ref, y[:, j, :], reduction='none')
            pair_kls.append(pointwise.sum(dim=1))
    if not pair_kls:
        # No pair to compare: report zero divergence instead of failing on a
        # max over an empty dimension.
        return y.new_zeros(y.shape[0])
    # Stacking the per-pair results keeps y's device and dtype.
    return torch.stack(pair_kls, dim=1).max(dim=1).values
def margin_tar_mean(logits, tar):
    '''
    Mean over passes of (target logit - largest non-target logit)
    input
        logits: shape = (N, T, C)
        tar: shape = (N,), target class indices
    output
        shape = (N,)
    '''
    n_passes = logits.shape[1]
    # One target index per (sample, pass) so gather/scatter act on the class axis.
    tar_idx = tar.long().reshape(-1, 1, 1).expand(-1, n_passes, 1)  # (N, T, 1)
    target_logit = torch.gather(logits, dim=2, index=tar_idx)  # (N, T, 1)
    # Mask the target by position, not by value, so another class that ties
    # with the target logit still counts as a rival.
    rivals = logits.scatter(2, tar_idx, float('-inf'))
    best_rival = rivals.max(dim=2).values  # (N, T)
    margins = target_logit.squeeze(2) - best_rival
    return margins.mean(dim=1)
    
def margin_mean(logits, tar=None):
    '''
    Mean over passes of (largest logit - second largest logit)
    input
        logits: shape = (N, T, C)
        tar: unused; accepted so the call signature matches the target-based
             metrics
    output
        shape = (N,)
    '''
    top_val, top_idx = logits.max(dim=2, keepdim=True)  # (N, T, 1) each
    # Mask only the arg-max position: when several classes tie for the maximum
    # the runner-up is the tied value and the margin is 0.
    runner_up = logits.scatter(2, top_idx, float('-inf')).max(dim=2).values
    margins = top_val.squeeze(2) - runner_up  # (N, T)
    return margins.mean(dim=1)

def variation_ratio(y):
    '''
    Uncertainty in Deep Learning(2016)
    One minus the fraction of passes that predict the most frequent class.
    input
        y: shape = (N, T, C)
    output
        shape = (N,)
    '''
    pred = y.argmax(dim=2)  # (N, T)
    # The most frequent predicted class; taking the max of the class indices
    # here would pick the largest label instead of the mode.
    mode = torch.mode(pred, dim=1, keepdim=True).values  # (N, 1)

    cnt = torch.sum(pred == mode, dim=1)  # (N,)
    return 1 - cnt * 1.0 / y.shape[1]

def softmax_average(y):
    '''
    Uncertainty Based Detection and Relabeling of Noisy Image Labels
    Average y over dim 1, then return the largest averaged value along the
    remaining dim 1.
    '''
    mean_probs = torch.mean(y, dim=1)
    top_prob, _ = torch.max(mean_probs, dim=1)
    return top_prob




def distributional_uncertainty(alphas):
    '''
    Predictive Uncertainty Estimation via Prior Networks: Eq[17]
    Entropy of the mean categorical (alphas normalised over dim 1) minus the
    expected entropy computed from alphas.
    '''
    precision = alphas.sum(dim=1, keepdim=True)
    mean_probs = alphas / precision

    entropy_of_mean = categorical_entropy_torch(mean_probs)
    mean_entropy = expected_entropy_from_alphas(alphas, precision)
    return entropy_of_mean - mean_entropy


def expected_entropy_from_alphas(alphas, alpha0=None):
    '''
    -sum over dim 1 of (alphas / alpha0) * (digamma(alphas + 1) - digamma(alpha0 + 1))
    alpha0 defaults to the sum of alphas over dim 1 (dim kept).
    '''
    if alpha0 is None:
        alpha0 = alphas.sum(dim=1, keepdim=True)
    mean_probs = alphas / alpha0
    digamma_gap = torch.digamma(alphas + 1) - torch.digamma(alpha0 + 1)
    return -(mean_probs * digamma_gap).sum(dim=1)


def categorical_entropy_torch(probs, dim=1, keepdim=False):
    """Categorical entropy -sum(p * log p) over `dim`, computed purely in torch.

    Entries whose log is not finite (e.g. p == 0) contribute zero.
    """
    log_probs = probs.log()
    safe_log_probs = log_probs.masked_fill(~torch.isfinite(log_probs), 0.0)
    return -(probs * safe_log_probs).sum(dim=dim, keepdim=keepdim)


def get_tar_score(y, tar):
    '''
    Pick, for every sample and pass, the score of that sample's target class
    input
        y: shape = (N, T, C)
        tar: shape = (N,)
    output
        shape = (N, T)
    '''
    col = tar.unsqueeze(1)  # (N, 1)
    tiled = torch.cat([col] * y.shape[1], dim=1)  # (N, T)
    index = tiled.unsqueeze(2).long()  # (N, T, 1)
    return y.gather(2, index).squeeze(2)  # (N, T)

def energy_score(output):
    '''Negative log-sum-exp of output over dim 1.'''
    log_partition = output.logsumexp(dim=1)
    return torch.neg(log_partition)


# Registry mapping a metric name to the function that implements it.
# The call signatures are not uniform: target_variance, target_mean and
# margin_mean also accept a target argument, and differential_entropy,
# u_sublogic and distributional_uncertainty are written in terms of `alphas`.
# margin_tar_mean, accuracy, variation_ratio and softmax_average are defined
# in this file but are not registered here.
METRICS = {
    'energy_score': energy_score,
    'margin_mean': margin_mean,
    'kl_divergence': kl_divergence,
    'kl_divergence_max': kl_divergence_max,
    'predictive_entropy': predictive_entropy,
    'mean_class_variance': mean_class_variance,
    'max_variance': max_variance,
    'max_prob': max_prob,
    'reject_score': reject_score,
    'target_variance': target_variance,
    'target_mean': target_mean,
    'mutual_information': mutual_information,
    'differential_entropy': differential_entropy,
    'u_sublogic': u_sublogic,
    'kernel_distance': kernel_distance,
    'distributional_uncertainty': distributional_uncertainty,
}

if __name__ == '__main__':
    # Smoke test: random softmax scores of shape (20, 30, 10) with every
    # target set to class 1; prints the shape of the resulting metric.
    scores = F.softmax(torch.randn(20, 30, 10), dim=2)
    targets = torch.ones((20,), dtype=torch.int64)
    print(target_variance(scores, targets).shape)
