import numpy as np
import torch
from torch import softmax
import warnings

from nnlib.nnlib import utils
from methods import LangevinDynamics
from torchmetrics.classification import MulticlassCalibrationError as ECE

# Import added
import scipy.optimize as opt
from sklearn.feature_selection import mutual_info_classif
from sklearn.neighbors import KernelDensity

def discrete_mi_est(xs, ys, nx=2, ny=2):
    """Plug-in estimate (in nats) of the mutual information of two discrete samples.

    ``xs`` and ``ys`` are paired sequences of integer symbols; each value is
    used directly as a row (``xs``) or column (``ys``) index into an
    ``nx`` x ``ny`` empirical joint-frequency table.  The result is clamped
    at zero from below.
    """
    joint = np.zeros((nx, ny))
    for x_val, y_val in zip(xs, ys):
        # Every observed pair carries the same share of the empirical mass.
        joint[x_val, y_val] += 1.0 / len(xs)

    marginal_x = joint.sum(axis=1)
    marginal_y = joint.sum(axis=0)

    total = 0
    for (row, col), p_xy in np.ndenumerate(joint):
        # Cells with (numerically) no mass contribute nothing and would
        # otherwise produce log(0).
        if p_xy >= 1e-9:
            total += p_xy * np.log(p_xy / (marginal_x[row] * marginal_y[col]))
    return max(0.0, total)

def mutual_information_kernel_density(X, y, bandwidth=1.0, num_points=100):
    """Kernel-density based mutual-information score between ``X`` and ``y``.

    Three Gaussian ``KernelDensity`` models (for ``X``, for ``y`` as a single
    column, and for the column-stacked pair) are fitted and evaluated on the
    very samples they were fitted on.  The score is
    ``sum(p_xy * log(p_xy / (p_x * p_y)))`` over those samples, clamped at
    zero from below.

    ``bandwidth`` is shared by all three models.  ``num_points`` is accepted
    but not used by this implementation.
    """
    y_column = y.reshape(-1, 1)
    joint_samples = np.column_stack((X, y))

    def _in_sample_density(samples):
        # Fit a Gaussian KDE and return its density at the training points.
        model = KernelDensity(bandwidth=bandwidth, kernel='gaussian')
        model.fit(samples)
        return np.exp(model.score_samples(samples))

    p_x = _in_sample_density(X)
    p_y = _in_sample_density(y_column)
    p_xy = _in_sample_density(joint_samples)

    score = (p_xy * np.log(p_xy / (p_x * p_y))).sum()
    return max(0.0, score)

def mutual_information_kde_bin(loss_bin, masks, bandwidth=1.):
    """Kernel-density mutual-information score pooled over all bins.

    ``loss_bin`` is indexed by bin along its first axis; each ``loss_bin[i]``
    is column-stacked with ``masks`` and the per-bin results are flattened to
    rows of width 3 (so each bin slice presumably holds two loss columns and
    ``masks`` one value per row -- confirm against the caller).  The masks are
    repeated once per bin so they line up with the flattened losses.

    Gaussian ``KernelDensity`` models are fitted to the losses, the masks and
    the joint rows, evaluated in-sample, and the score
    ``sum(p_xy * log(p_xy / (p_x * p_y)))`` is returned, clamped at zero.
    """
    bin_count = len(loss_bin)

    # Assemble the three sample matrices.
    joint_rows = np.array(
        [np.column_stack((per_bin, masks)) for per_bin in loss_bin]
    ).reshape(-1, 3)
    loss_rows = loss_bin.reshape(-1, 2)
    mask_rows = np.tile(masks, bin_count).reshape(-1, 1)

    def _in_sample_density(samples):
        # Fit a Gaussian KDE and return its density at the training points.
        model = KernelDensity(bandwidth=bandwidth, kernel='gaussian')
        model.fit(samples)
        return np.exp(model.score_samples(samples))

    p_x = _in_sample_density(loss_rows)
    p_y = _in_sample_density(mask_rows)
    p_xy = _in_sample_density(joint_rows)

    score = (p_xy * np.log(p_xy / (p_x * p_y))).sum()
    return max(0.0, score)


def interpolate_nan(a):
    """Fill the NaN entries of a 1d numpy array by linear interpolation.

    Interior NaNs are interpolated between their non-NaN neighbours (using the
    array positions as x-coordinates); NaNs at either end take the nearest
    non-NaN value.  The input is left untouched and the result is returned as
    a float32 torch tensor.
    Slightly modified From the code in the "minimum-calibration..." NeurIPS2023.
    """
    filled = a.copy()
    missing = np.isnan(filled)
    known = ~missing
    # np.interp clamps to the first/last known value outside the known range,
    # which gives the nearest-value behaviour on the boundaries.
    filled[missing] = np.interp(
        np.flatnonzero(missing), np.flatnonzero(known), filled[known]
    )
    return torch.tensor(filled).float()

def idx_bins(confidence, n_bins):
    """Map each confidence value to a zero-based bin index.

    Despite its name, ``n_bins`` is used as the array of bin edges handed to
    ``np.digitize``.  Values at or beyond the last edge are assigned to the
    last bin (index ``len(n_bins) - 2``).  ``confidence`` must support
    ``.numpy()``; the indices come back as a torch tensor.
    """
    edges = n_bins
    last_bin = len(edges) - 2
    # digitize is one-based for in-range values, hence the shift by one.
    positions = np.digitize(confidence.numpy(), edges) - 1
    return torch.tensor(np.minimum(positions, last_bin))

def calc_recalibration(ps, n_bins):
    """Replace every value of ``ps`` by the mean of ``ps`` over its bin.

    ``n_bins`` is forwarded to ``idx_bins`` as the bin edges, so there are
    ``len(n_bins) - 1`` bins.  Bins that received no sample have an undefined
    mean (0/0); those are filled in by ``interpolate_nan`` before the lookup.
    Returns a float tensor aligned with ``ps``.
    """
    bin_ids = idx_bins(ps, n_bins)
    num_bins = len(n_bins) - 1

    # Per-bin sample count, and per-bin sum of the values in ``ps``.
    counts = torch.bincount(bin_ids, minlength=num_bins).float().to(ps.device)
    sums = torch.bincount(bin_ids, weights=ps, minlength=num_bins).float().to(ps.device)

    with warnings.catch_warnings():
        # Empty bins divide 0 by 0; silence the resulting numpy warnings.
        warnings.filterwarnings('ignore')
        per_bin_mean = interpolate_nan(sums.numpy() / counts.numpy())  ## \hat{\mu} in Eq.(9) of Sun et al. (2023)
    return per_bin_mean[bin_ids]

def estimate_fcmi_bound_uwb(masks, preds, labels, bins, num_examples, num_classes, n_bins, bins_list=None,
                            norm='l1', loss='diff', recalibration=False, verbose=False, return_list_of_mis=False):
    """
    Estimating our bound's value (Uniform width binning).

    For every example index ``idx`` a mutual-information term is estimated
    with ``mutual_info_classif`` between a two-column loss feature (built from
    rows ``2*idx`` and ``2*idx+1`` of every entry of ``preds`` / ``labels``)
    and the mask values ``masks[k][idx]``.  The terms are summed and plugged
    into the closed-form expression selected by ``norm`` and ``recalibration``.

    Args:
        masks: sequence; ``masks[k][idx]`` is the mask value of entry ``k``
            for example ``idx``.
        preds: sequence of 2-D tensors (one per entry ``k``).  Rows that do
            not already sum to 1 are passed through softmax first.
        labels: sequence of tensors, sliced like ``preds``.
        bins: sequence; ``bins[k][idx]`` selects the bin slot that receives
            the loss of entry ``k``.
        num_examples: number of example indices to iterate over.
        num_classes: unused; kept for interface compatibility.
        n_bins: either the number of bins (int) or a sized object with
            ``len(n_bins) - 1`` bins (presumably the bin edges).
        bins_list: ``bins_list[k]`` is handed to ``calc_recalibration``;
            required when ``loss='diff_recalibrate'``.
        norm: ``'l1'`` or ``'l2'``.
        loss: ``'diff'``, ``'softmax'``, ``'reuse'`` or ``'diff_recalibrate'``.
        recalibration: selects the recalibrated variant of the final formula.
        verbose: print the intermediate quantities of the first 10 examples.
        return_list_of_mis: also return the per-example MI estimates.

    Returns:
        The bound, or ``(bound, list_of_mis)`` if ``return_list_of_mis``.

    Raises:
        ValueError: on an unknown ``loss`` / ``norm``, or when
            ``loss='diff_recalibrate'`` is used without ``bins_list``.
    """
    # Validate up front so that a typo does not cost a full estimation run.
    binned_losses = ('diff', 'diff_recalibrate', 'reuse')
    if loss not in binned_losses and loss != 'softmax':
        raise ValueError(f"Unexpected loss type: {loss}")
    if norm not in ('l1', 'l2'):
        raise ValueError(f"Unexpected norm type: {norm}")
    if loss == 'diff_recalibrate' and bins_list is None:
        raise ValueError("bins_list is required when loss='diff_recalibrate'")

    if isinstance(n_bins, (int, np.integer)):
        B = int(n_bins)
    else:
        B = len(n_bins) - 1
    use_bins = loss in binned_losses

    list_of_mis = []
    total_mi = 0.0
    for idx in range(num_examples):
        ms = np.array([p[idx] for p in masks])
        ps = [p[2*idx:2*idx+2] for p in preds]
        ls = [l[2*idx:2*idx+2] for l in labels]
        bs = np.array([b[idx] for b in bins])

        losses = []
        if use_bins:
            loss_bin = torch.zeros(B, len(ps), 2)
        for i in range(len(ps)):
            if loss == 'reuse':
                # The labels themselves act as the binned feature.
                loss_bin[bs[i]][i] = ls[i]
                continue

            # NOTE(review): the 1e-10 tolerance is tighter than float32
            # round-off, so float32 probabilities may be softmaxed again --
            # confirm that callers pass logits (or float64 probabilities).
            if not torch.all(torch.abs(torch.sum(ps[i], dim=1) - 1) < 1e-10):
                ps[i] = torch.max(softmax(ps[i], 1), dim=1).values ## predictive values
            else:
                ps[i] = torch.max(ps[i], dim=1).values ## predictive values
            ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0] ## f(x) is the predictive prob. for y=1.

            if loss == 'softmax':
                losses.append(ps[i])
                continue
            if loss == 'diff_recalibrate':
                ps[i] = calc_recalibration(ps[i], bins_list[i])
            loss_bin[bs[i]][i] = ls[i] - ps[i]

        if use_bins:
            cur_mi = max(0., mutual_info_classif(loss_bin.reshape(-1,2), np.tile(ms, B), discrete_features=[False, False]).sum())
        else:
            ps = torch.concat(losses).reshape(-1,2).numpy()
            cur_mi = max(0., mutual_info_classif(ps, ms, discrete_features=[False, False]).sum())
        list_of_mis.append(cur_mi)
        total_mi += cur_mi

        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    if norm == 'l1':
        if recalibration:
            bound = np.sqrt((2*(total_mi + B*np.log(2))) / num_examples)
        else:
            bound = np.sqrt((8*(total_mi + B*np.log(2))) / num_examples)
    else:  # norm == 'l2' (validated above)
        # ``list_of_mis`` is a plain list and has no ``.sum()``; the running
        # total is the same quantity.
        if recalibration:
            bound = np.sqrt(4*(total_mi + B*np.log(3)) / num_examples)  + 1/B + np.sqrt(B/(4*num_examples))
        else:
            bound = np.sqrt(4*(total_mi + B*np.log(3)) / num_examples)  + np.sqrt(B/(4*num_examples))

    if return_list_of_mis:
        return bound, list_of_mis

    return bound

def estimate_fcmi_bound_umb(masks, preds, labels, bins, num_examples, num_classes, n_bins, bins_list=None,
                            norm='l1', loss='diff', recalibration=False, verbose=False, return_list_of_mis=False):
    """
    Estimating our bound's value (Uniform mass binning).

    For every example index ``idx`` a mutual-information term is estimated
    with ``mutual_info_classif`` between a two-column loss feature (built from
    rows ``2*idx`` and ``2*idx+1`` of every entry of ``preds`` / ``labels``)
    and the mask values ``masks[k][idx]``.  The terms are summed and plugged
    into the closed-form ``'l1'`` expression selected by ``recalibration``.

    Args:
        masks: sequence; ``masks[k][idx]`` is the mask value of entry ``k``
            for example ``idx``.
        preds: sequence of 2-D tensors (one per entry ``k``).  Rows that do
            not already sum to 1 are passed through softmax first.
        labels: sequence of tensors, sliced like ``preds``.
        bins: sequence; ``bins[k][idx]`` selects the bin slot that receives
            the loss of entry ``k``.
        num_examples: number of example indices to iterate over.
        num_classes: unused; kept for interface compatibility.
        n_bins: either the number of bins (int) or a sized object whose
            length is the number of bins.
        bins_list: ``bins_list[k]`` is handed to ``calc_recalibration``;
            required when ``loss='diff_recalibrate'``.
        norm: only ``'l1'`` is implemented.
        loss: ``'diff'``, ``'softmax'``, ``'reuse'`` or ``'diff_recalibrate'``.
        recalibration: selects the recalibrated variant of the final formula.
        verbose: print the intermediate quantities of the first 10 examples.
        return_list_of_mis: also return the per-example MI estimates.

    Returns:
        The bound, or ``(bound, list_of_mis)`` if ``return_list_of_mis``.

    Raises:
        ValueError: on an unknown ``loss`` / ``norm``, for the not yet
            implemented ``norm='l2'``, or when ``loss='diff_recalibrate'``
            is used without ``bins_list``.
    """
    # Validate up front so that an unsupported option does not cost a full
    # estimation run before being rejected.
    binned_losses = ('diff', 'diff_recalibrate', 'reuse')
    if loss not in binned_losses and loss != 'softmax':
        raise ValueError(f"Unexpected loss type: {loss}")
    if norm == 'l2':
        raise ValueError(f"Haven't yet implement the bound for the following norm type: {norm}")
    if norm != 'l1':
        raise ValueError(f"Unexpected norm type: {norm}")
    if loss == 'diff_recalibrate' and bins_list is None:
        raise ValueError("bins_list is required when loss='diff_recalibrate'")

    if isinstance(n_bins, (int, np.integer)):
        B = int(n_bins)
    else:
        B = len(n_bins)
    use_bins = loss in binned_losses

    list_of_mis = []
    total_mi = 0.0
    for idx in range(num_examples):
        ms = np.array([p[idx] for p in masks])
        ps = [p[2*idx:2*idx+2] for p in preds]
        ls = [l[2*idx:2*idx+2] for l in labels]
        bs = np.array([b[idx] for b in bins])

        losses = []
        if use_bins:
            loss_bin = torch.zeros(B, len(ps), 2)
        for i in range(len(ps)):
            if loss == 'reuse':
                # The labels themselves act as the binned feature.
                loss_bin[bs[i]][i] = ls[i]
                continue

            # NOTE(review): the 1e-10 tolerance is tighter than float32
            # round-off, so float32 probabilities may be softmaxed again --
            # confirm that callers pass logits (or float64 probabilities).
            if not torch.all(torch.abs(torch.sum(ps[i], dim=1) - 1) < 1e-10):
                ps[i] = torch.max(softmax(ps[i], 1), dim=1).values ## predictive values
            else:
                ps[i] = torch.max(ps[i], dim=1).values ## predictive values
            ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0] ## f(x) is the predictive prob. for y=1.

            if loss == 'softmax':
                losses.append(ps[i])
                continue
            if loss == 'diff_recalibrate':
                ps[i] = calc_recalibration(ps[i], bins_list[i])
            loss_bin[bs[i]][i] = ls[i] - ps[i]

        if use_bins:
            cur_mi = max(0., mutual_info_classif(loss_bin.reshape(-1,2), np.tile(ms, B), discrete_features=[False, False]).sum())
        else:
            ps = torch.concat(losses).reshape(-1,2).numpy()
            cur_mi = max(0., mutual_info_classif(ps, ms, discrete_features=[False, False]).sum())
        list_of_mis.append(cur_mi)
        total_mi += cur_mi

        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    # Only the 'l1' bound reaches this point (validated above).
    if recalibration:
        bound = np.sqrt((2*(total_mi + B*np.log(2))) / num_examples)
    else:
        bound = np.sqrt((8*(total_mi + B*np.log(2))) / num_examples)

    if return_list_of_mis:
        return bound, list_of_mis

    return bound


def estimate_fcmi_bound_baseline(masks, preds, labels, num_examples, num_classes,
                                 loss='diff', verbose=False, return_list_of_mis=False):
    """
    Estimating our baseline bound's value.

    For every example index ``idx`` the rows ``2*idx`` and ``2*idx+1`` of each
    entry of ``preds`` are reduced to their maximum probability (softmax is
    applied first when the rows do not already sum to 1).  With
    ``loss='diff'`` the feature is the squared difference between label and
    (label-flipped) probability; with ``loss='softmax'`` it is the maximum
    probability itself.  ``mutual_info_classif`` between that two-column
    feature and the mask values ``masks[k][idx]`` gives the per-example term;
    the bound is ``sqrt(4 * (sum_of_terms + log 3) / num_examples)``.

    ``num_classes`` is unused.  Any other ``loss`` raises ``ValueError``.
    Returns the bound, or ``(bound, list_of_mis)`` if ``return_list_of_mis``.
    """
    use_diff = loss == 'diff'
    mi_sum = 0.0
    list_of_mis = []
    for idx in range(num_examples):
        pair_rows = slice(2*idx, 2*idx + 2)
        ms = [m[idx] for m in masks]
        ps = [p[pair_rows] for p in preds]
        ls = [l[pair_rows] for l in labels]

        squared_errors = []
        for i, raw in enumerate(ps):
            # Only apply softmax when the rows are not already probabilities.
            if torch.all(torch.abs(torch.sum(raw, dim=1) - 1) < 1e-10):
                probs = raw
            else:
                probs = softmax(raw, 1)
            confidence = torch.max(probs, dim=1).values ## predictive values
            ps[i] = confidence

            if use_diff:
                negatives = ls[i] == 0
                confidence[negatives] = 1 - confidence[negatives] ## f(x) is the predictive prob. for y=1.
                squared_errors.append((ls[i] - confidence)**2)

        if use_diff:
            ps = torch.concat(squared_errors).reshape(-1,2).numpy()
        elif loss == 'softmax':
            ps = torch.concat(ps).reshape(-1,2).numpy()
        else:
            raise ValueError(f"Unexpected loss type: {loss}")

        cur_mi = mutual_info_classif(ps, ms, discrete_features=[False, False]).sum() ## masks are discrete
        list_of_mis.append(cur_mi)
        mi_sum += cur_mi
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    bound = np.sqrt(4*(mi_sum + np.log(3))/num_examples)

    if return_list_of_mis:
        return bound, list_of_mis

    return bound

def estimate_fcmi_bound_classification(masks, preds, num_examples, num_classes, knn=False,
                                       verbose=False, return_list_of_mis=False):
    """Estimate the bound ``mean_idx sqrt(2 * MI_idx)`` from predicted classes.

    For every example index ``idx`` the rows ``2*idx`` and ``2*idx+1`` of each
    entry of ``preds`` are reduced to their argmax class.  The mutual
    information between that pair of predictions and the mask values
    ``masks[k][idx]`` is then estimated in one of two ways:

    * ``knn`` truthy: ``mutual_info_classif`` with both prediction columns
      treated as continuous features (clamped at zero);
    * otherwise: each pair is encoded as the single symbol
      ``num_classes * first + second`` and passed to ``discrete_mi_est``
      (which indexes its table by the mask values, so they must be 0/1 ints).

    Returns the bound, or ``(bound, list_of_mis)`` if ``return_list_of_mis``.
    """
    bound = 0.0
    list_of_mis = []
    for idx in range(num_examples):
        ms = [p[idx] for p in masks]
        ps = [torch.argmax(p[2*idx:2*idx+2], dim=1) for p in preds]
        # One truthiness test drives both the encoding and the estimator, so
        # falsy values such as ``None`` cannot take a mixed path.
        if knn:
            ps = torch.concat(ps).reshape(-1,2).numpy()
            cur_mi = max(0., mutual_info_classif(ps, np.array(ms), discrete_features=[False, False]).sum())
        else:
            ps = [(num_classes * pair[0] + pair[1]).item() for pair in ps]
            cur_mi = discrete_mi_est(ms, ps, nx=2, ny=num_classes**2)
        list_of_mis.append(cur_mi)
        bound += np.sqrt(2 * cur_mi)
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)
    bound *= 1/num_examples

    if return_list_of_mis:
        return bound, list_of_mis

    return bound

def kl(q,p):
    """Binary KL divergence ``KL(Bernoulli(q) || Bernoulli(p))`` in nats.

    The terms with zero weight are dropped (convention ``0 * log 0 = 0``):
    for ``q <= 0`` the result is ``log(1 / (1 - p))`` and for ``q >= 1`` it is
    ``log(1 / p)``.  Without the latter case ``q == 1`` would evaluate
    ``0 * log(0)`` and return NaN.
    """
    if q>0:
        if q >= 1:
            return np.log( 1/p )
        return q*np.log(q/p) + (1-q)*np.log( (1-q)/(1-p) )
    else:
        return np.log( 1/(1-p) )

# Function added
def estimate_interp_bound_classification(masks, preds, num_examples, num_classes, train_acc,
                                         verbose=False, return_list_of_mis=False):
    """Bound for the zero-training-error case: ``mean MI / log 2``.

    For every example index ``idx`` the rows ``2*idx`` and ``2*idx+1`` of each
    entry of ``preds`` are reduced to their argmax class, and
    ``mutual_info_classif`` (both columns treated as discrete) between those
    two predictions and the mask values ``masks[k][idx]`` gives the
    per-example term.  If ``1 - train_acc`` is exactly 0 the bound is the
    average term divided by ``log 2``; otherwise the value 1 is returned.

    ``num_classes`` is unused.  Returns the bound, or
    ``(bound, list_of_mis)`` if ``return_list_of_mis``.
    """
    mi_sum = 0.0
    list_of_mis = []
    for idx in range(num_examples):
        pair_rows = slice(2*idx, 2*idx + 2)
        ms = [m[idx] for m in masks]
        predicted = [torch.argmax(p[pair_rows], dim=1) for p in preds]
        ps = torch.concat(predicted).reshape(-1,2).numpy()
        cur_mi = mutual_info_classif(ps, ms, discrete_features=[True, True]).sum() ## masks are discrete
        list_of_mis.append(cur_mi)
        mi_sum += cur_mi
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)
    RHS = mi_sum * (1/num_examples)

    train_error = 1-train_acc
    # The closed form only applies when the training error is exactly zero;
    # in every other case the value 1 is returned.
    bound = RHS/np.log(2) if train_error == 0 else 1

    if return_list_of_mis:
        return bound, list_of_mis

    return bound

# Function added
def estimate_kl_bound_classification(masks, preds, num_examples, num_classes, train_acc,
                                     verbose=False, return_list_of_mis=False):
    """Bound obtained by numerically inverting a binary-KL inequality.

    The per-example terms are computed as follows: rows ``2*idx`` and
    ``2*idx+1`` of each entry of ``preds`` are reduced to their argmax class,
    and ``mutual_info_classif`` (both columns discrete) is evaluated against
    the mask values ``masks[k][idx]``.  With ``RHS`` the average term and
    ``Rhat = 1 - train_acc``, the returned value is the largest ``R`` found by
    ``scipy.optimize.minimize`` such that
    ``(RHS - kl(Rhat, Rhat/2 + R/2)) * R * (1 - R) >= 0``.

    ``num_classes`` is unused.  Returns the bound, or
    ``(bound, list_of_mis)`` if ``return_list_of_mis``.
    """
    mi_sum = 0.0
    list_of_mis = []
    for idx in range(num_examples):
        pair_rows = slice(2*idx, 2*idx + 2)
        ms = [m[idx] for m in masks]
        predicted = [torch.argmax(p[pair_rows], dim=1) for p in preds]
        ps = torch.concat(predicted).reshape(-1,2).numpy()
        cur_mi = mutual_info_classif(ps, ms, discrete_features=[True, True]).sum() ## masks are discrete
        list_of_mis.append(cur_mi)
        mi_sum += cur_mi
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)
    RHS = mi_sum * (1/num_examples)

    Rhat = 1-train_acc

    def feasible(R):
        # 'ineq' constraints must be non-negative.  The factors R and (1-R)
        # keep R inside [0, 1]; the first factor encodes the KL inequality.
        return (RHS-kl(Rhat,Rhat/2 + R/2))*R*(1-R)

    # Maximise R by minimising -R under the constraint above.
    results = opt.minimize(
        lambda R: -R,
        x0=0.5,
        constraints={'type': 'ineq', 'fun' : feasible},
        options={'disp':False},
    )

    bound = results.x[0]
    if return_list_of_mis:
        return bound, list_of_mis

    return bound

# Function added
def estimate_lg_bound_classification(masks, preds, num_examples, num_classes, train_acc,
                                     verbose=False, return_list_of_mis=False):
    """Bound obtained by minimising ``x[1]*Rhat + RHS/x[0]`` under a constraint.

    The per-example terms are computed as follows: rows ``2*idx`` and
    ``2*idx+1`` of each entry of ``preds`` are reduced to their argmax class,
    and ``mutual_info_classif`` (both columns discrete) is evaluated against
    the mask values ``masks[k][idx]``.  With ``RHS`` the average term and
    ``Rhat = 1 - train_acc``, ``scipy.optimize.minimize`` searches
    ``x[0]`` in ``[0, 0.37]`` and ``x[1]`` in ``[1, inf)`` subject to
    ``-x[0]*(1-x[1]) - (exp(x[0]) - 1 - x[0]) * (1 + x[1]**2) >= 0``; the
    objective value at the returned point is the bound.

    ``num_classes`` is unused.  Returns the bound, or
    ``(bound, list_of_mis)`` if ``return_list_of_mis``.
    """
    mi_sum = 0.0
    list_of_mis = []
    for idx in range(num_examples):
        pair_rows = slice(2*idx, 2*idx + 2)
        ms = [m[idx] for m in masks]
        predicted = [torch.argmax(p[pair_rows], dim=1) for p in preds]
        ps = torch.concat(predicted).reshape(-1,2).numpy()
        cur_mi = mutual_info_classif(ps, ms, discrete_features=[True, True]).sum() ## masks are discrete
        list_of_mis.append(cur_mi)
        mi_sum += cur_mi
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)
    RHS = mi_sum * (1/num_examples)

    Rhat = 1-train_acc

    def objective(x):
        return x[1]*Rhat + RHS/x[0]

    def feasible(x):
        # 'ineq' constraints must be non-negative at a feasible point.
        return (-x[0]*(1-x[1]) - (np.exp(x[0]) - 1 - x[0] ) * ( 1 + x[1]**2 ))

    # NOTE(review): the start point x0=[3, 2] lies outside the (0, 0.37) box
    # for x[0], and disp=True makes the optimiser print its summary regardless
    # of ``verbose`` -- confirm both are intended.
    results = opt.minimize(
        objective,
        x0=[3,2],
        constraints={'type': 'ineq', 'fun' : feasible},
        bounds=((0, 0.37),(1,np.inf)),
        options={'disp':True},
    )

    bound = objective(results.x)

    if return_list_of_mis:
        return bound, list_of_mis

    return bound





def estimate_sgld_bound(n, batch_size, model):
    """ Computes the bound of Negrea et al. "Information-Theoretic Generalization Bounds for
    SGLD via Data-Dependent Estimates". Eq (6) of https://arxiv.org/pdf/1911.02151.pdf.

    ``model`` must be a ``LangevinDynamics`` instance with gradient-variance
    tracking enabled, whose learning-rate and beta histories hold exactly one
    more entry than its gradient-variance history.  The result is
    ``sqrt(sum_t lr[t] * beta[t] / 4 * grad_variance[t-1])`` scaled by
    ``sqrt(n / (4 * batch_size * (n-1)**2))``.
    """
    assert isinstance(model, LangevinDynamics)
    assert model.track_grad_variance
    grad_variances = model._grad_variance_hist
    learning_rates = model._lr_hist
    betas = model._beta_hist
    num_tracked = len(grad_variances)
    assert len(learning_rates) == num_tracked + 1
    assert len(betas) == num_tracked + 1

    accumulated = 0.0
    # skipping the first iteration as grad_variance was not tracked for it
    for step in range(1, num_tracked):
        accumulated += learning_rates[step] * betas[step] / 4.0 * grad_variances[step-1]

    ret = np.sqrt(utils.to_numpy(accumulated))
    ret *= np.sqrt(n / 4.0 / batch_size / (n-1) / (n-1))
    return ret
