import os
import pickle
import argparse

import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier

import sklearn
import sklearn.model_selection as ms
import numpy as np

import torch
from nnlib.nnlib import utils
from modules.bound_utils import estimate_fcmi_bound_umb

import warnings
from tqdm.auto import tqdm
import h5py

def compute_acc(preds, mask, labels):
    """Return the accuracy on the half of each sample pair chosen by ``mask``.

    Samples are laid out in consecutive pairs; for pair ``i`` the element at
    position ``2*i + mask[i]`` is evaluated.

    Args:
        preds: per-sample class scores, one row per sample.
        mask: one 0/1 entry per pair selecting which element to evaluate.
        labels: per-sample class labels aligned with ``preds``.

    Returns:
        The mean accuracy converted with ``utils.to_numpy``.
    """
    score_tensor = torch.tensor(preds)
    label_tensor = torch.tensor(labels).long()
    selected = 2*np.arange(len(mask)) + mask
    predicted_classes = score_tensor[selected].argmax(dim=1)
    is_correct = predicted_classes == label_tensor[selected]
    return utils.to_numpy(is_correct.float().mean())

def isnotnan(pred):
    """Drop every row of the 2-D array ``pred`` that contains a NaN."""
    row_has_nan = np.isnan(pred).any(axis=1)
    keep_rows = np.logical_not(row_has_nan)
    return pred[keep_rows, :]

def calc_recab_ECE(confidences, labels, conf_cal, labels_cal, n_bins, norm='l1', strategy='label'):
    """
    Calculate the recalibrated ECE using a separate calibration dataset.

    A per-bin recalibration value is estimated from ``conf_cal`` /
    ``labels_cal``; the evaluation confidences are then replaced by the value
    of the bin they fall into, and the ECE of those replaced confidences is
    computed against ``labels``.

    Args:
        confidences: 2-D tensor of per-class scores for the evaluation samples.
            Rows that do not sum to 1 (within 1e-10) trigger a softmax over dim 1.
        labels: labels of the evaluation samples. Used as ``scatter_add_`` source
            and to select the ``labels == 0`` samples
            (assumes an integer tensor with values in {0, 1} -- TODO confirm).
        conf_cal: per-class scores of the calibration samples, same handling as
            ``confidences``.
        labels_cal: labels of the calibration samples.
        n_bins: sequence of bin edges (its length is used and it is passed to
            ``idx_bins``); despite the name this is not a bin count.
        norm: 'l1' or 'l2', selects how per-bin gaps are aggregated.
        strategy: 'label' weights calibration bins by ``labels_cal``;
            'probability' weights them by the calibration confidences.

    Returns:
        A scalar tensor holding the ECE.

    Raises:
        ValueError: if the scores are outside [0, 1] after the softmax check, or
            if ``strategy`` / ``norm`` is not one of the supported values.
    """
    # NOTE(review): 1e-10 is tighter than float32 rounding error, so float32
    # probabilities may be softmaxed a second time here -- confirm intended.
    if not torch.all(torch.abs(torch.sum(confidences, dim=1) - 1) < 1e-10):
        print("make softmax prob.")
        confidences = confidences.softmax(1)

    if not torch.all((confidences >= 0) & (confidences <= 1)):
        raise ValueError(f"This is not softmax prob.")

    # Keep only the largest class score per sample.
    confidences, _ = confidences.max(dim=1)
    confidences[labels==0] = 1 - confidences[labels==0] ## MEMO: Reverse prob. for label y=0.

    # Same normalisation / validation / reduction for the calibration scores.
    if not torch.all(torch.abs(torch.sum(conf_cal, dim=1) - 1) < 1e-10):
        print("make softmax prob.")
        conf_cal = conf_cal.softmax(1)

    if not torch.all((conf_cal >= 0) & (conf_cal <= 1)):
        raise ValueError(f"This is not softmax prob.")

    conf_cal, _ = conf_cal.max(dim=1)
    conf_cal[labels_cal==0] = 1 - conf_cal[labels_cal==0] ## MEMO: Reverse prob. for label y=0.


    with torch.no_grad():
        # Per-bin accumulators: one slot per entry of ``n_bins``.
        conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)
        count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)
        label_bin = torch.zeros(len(n_bins), device=labels.device, dtype=labels.dtype)

        #idx = torch.bucketize(conf_cal, n_bins, right=True) - 1
        # Bin index of every calibration sample.
        idx = idx_bins(conf_cal, n_bins)
        bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins
        if strategy == 'label':
            bin_true = (torch.bincount(idx, weights=labels_cal, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels
        elif strategy == 'probability':
            bin_true = (torch.bincount(idx, weights=conf_cal, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by calibration confidences
        else:
            raise ValueError(f"Unexpected strategy: {strategy}.")

        with warnings.catch_warnings():
            # Silence the divide-by-zero warnings raised for empty bins.
            warnings.filterwarnings('ignore')
            # fill nan by interpolation assuming smoothness
            # NOTE(review): ``.numpy()`` presumably requires CPU tensors -- confirm callers never pass GPU tensors.
            bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \hat{\mu} in Eq.(9) of Sun et al. (2023)

        ## prediction based on the recalibrate function
        # ``idx`` is re-bound to the evaluation samples' bins from here on.
        idx = idx_bins(confidences, n_bins)
        confidences = bin_mean[idx]

        # Per-bin sample count, mean recalibrated confidence and mean label.
        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences).type_as(count_bin))
        conf_bin.scatter_add_(dim=0, index=idx, src=confidences.type_as(conf_bin))
        # Empty bins give 0/0 = NaN, which nan_to_num turns into 0.
        conf_bin = torch.nan_to_num(conf_bin / count_bin)
        prop_bin = count_bin / count_bin.sum()

        label_bin.scatter_add_(dim=0, index=idx, src=labels)
        label_bin = torch.nan_to_num(label_bin / count_bin)

    # Aggregate the per-bin |mean label - mean confidence| gaps, weighted by bin mass.
    if norm == 'l1':
        ece = torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)
    elif norm == 'l2':
        ece = torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))
    else:
        raise ValueError(f"Unexpected norm type: {norm}")

    return ece

def calc_ECE(confidences, labels, n_bins, norm='l1', recalibrate=False, strategy='label'):
    """Calculate the binned ECE, optionally after in-sample recalibration.

    Args:
        confidences: 2-D tensor of per-class scores. Rows that do not sum to 1
            (within 1e-10) trigger a softmax over dim 1.
        labels: labels of the samples. Used as ``scatter_add_`` source and to
            select the ``labels == 0`` samples
            (assumes an integer tensor with values in {0, 1} -- TODO confirm).
        n_bins: sequence of bin edges (its length is used and it is passed to
            ``idx_bins``); despite the name this is not a bin count.
        norm: 'l1' or 'l2', selects how per-bin gaps are aggregated.
        recalibrate: when True, every confidence is replaced by the
            (NaN-interpolated) per-bin mean estimated from these same samples
            before the ECE is computed.
        strategy: only used when ``recalibrate`` is True. 'label' weights bins by
            ``labels``; 'probability' weights them by the confidences.

    Returns:
        A scalar tensor holding the ECE.

    Raises:
        ValueError: if the scores are outside [0, 1] after the softmax check, or
            if ``strategy`` / ``norm`` is not one of the supported values.
    """
    # NOTE(review): 1e-10 is tighter than float32 rounding error, so float32
    # probabilities may be softmaxed a second time here -- confirm intended.
    if not torch.all(torch.abs(torch.sum(confidences, dim=1) - 1) < 1e-10):
        print("make softmax prob.")
        confidences = confidences.softmax(1)

    if not torch.all((confidences >= 0) & (confidences <= 1)):
        raise ValueError(f"This is not softmax prob.")

    # Keep only the largest class score per sample.
    confidences, _ = confidences.max(dim=1)
    confidences[labels==0] = 1 - confidences[labels==0] ## MEMO: Reverse prob. for label y=0

    with torch.no_grad():
        # Per-bin accumulators: one slot per entry of ``n_bins``.
        conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)
        count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)
        label_bin = torch.zeros(len(n_bins), device=labels.device, dtype=labels.dtype)

        #idx = torch.bucketize(confidences, n_bins, right=True) - 1
        # Bin index of every sample; computed before any recalibration so the
        # binning always follows the original confidences.
        idx = idx_bins(confidences, n_bins)
        if recalibrate:
            bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins
            if strategy == 'label':
                bin_true = (torch.bincount(idx, weights=labels, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels
            elif strategy == 'probability':
                bin_true = (torch.bincount(idx, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by confidences
            else:
                raise ValueError(f"Unexpected strategy: {strategy}.")

            with warnings.catch_warnings():
                # Silence the divide-by-zero warnings raised for empty bins.
                warnings.filterwarnings('ignore')
                # fill nan by interpolation assuming smoothness
                # NOTE(review): ``.numpy()`` presumably requires CPU tensors -- confirm callers never pass GPU tensors.
                bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \hat{\mu} in Eq.(9) of Sun et al. (2023)
                # Replace each confidence by the recalibrated value of its bin.
                confidences = bin_mean[idx]

        # Per-bin sample count, mean confidence and mean label.
        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences).type_as(count_bin))
        conf_bin.scatter_add_(dim=0, index=idx, src=confidences.type_as(conf_bin))
        # Empty bins give 0/0 = NaN, which nan_to_num turns into 0.
        conf_bin = torch.nan_to_num(conf_bin / count_bin)
        prop_bin = count_bin / count_bin.sum()

        label_bin.scatter_add_(dim=0, index=idx, src=labels)
        label_bin = torch.nan_to_num(label_bin / count_bin)

    # Aggregate the per-bin |mean label - mean confidence| gaps, weighted by bin mass.
    if norm == 'l1':
        ece = torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)
    elif norm == 'l2':
        ece = torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))
    else:
        raise ValueError(f"Unexpected norm type: {norm}")

    return ece

#def compute_ece(preds, mask, dataset, n_bins, norm='l1', recalibrate=False, strategy='label'):
def compute_ece(preds, mask, labels, n_bins, norm='l1', recalibrate=False, strategy='label', cal_data=False, preds_cal=None, labels_cal=None):
    """Compute the ECE on the half of each sample pair chosen by ``mask``.

    Samples are laid out in consecutive pairs; for pair ``i`` the element at
    position ``2*i + mask[i]`` is evaluated.

    Args:
        preds: per-sample class scores, one row per sample.
        mask: one 0/1 entry per pair selecting which element to evaluate.
        labels: per-sample labels aligned with ``preds``.
        n_bins: bin edges forwarded to the ECE routine.
        norm: 'l1' or 'l2', forwarded to the ECE routine.
        recalibrate: forwarded to ``calc_ECE`` (ignored when ``cal_data`` is set).
        strategy: recalibration weighting, forwarded to the ECE routine.
        cal_data: when truthy, recalibrate with the external calibration set via
            ``calc_recab_ECE``; otherwise use ``calc_ECE``.
        preds_cal: calibration scores (only used when ``cal_data`` is set).
        labels_cal: calibration labels (only used when ``cal_data`` is set).

    Returns:
        The ECE converted with ``utils.to_numpy``.
    """
    score_tensor = torch.tensor(preds)
    label_tensor = torch.tensor(labels).long()
    selected = 2*np.arange(len(mask)) + mask
    chosen_scores = score_tensor[selected]
    chosen_labels = label_tensor[selected]

    if not cal_data:
        ece = calc_ECE(chosen_scores, chosen_labels, n_bins, norm=norm, recalibrate=recalibrate, strategy=strategy)
    else:
        ece = calc_recab_ECE(chosen_scores, chosen_labels, preds_cal, labels_cal, n_bins, norm=norm, strategy=strategy)
    return utils.to_numpy(ece)

def compute_bins(num_bins, confidences=None, method='uniform'):
    """Return the ``num_bins + 1`` bin edges used for ECE binning.

    Args:
        num_bins: number of bins.
        confidences: confidence values from which the quantile edges are
            computed; required when ``method='quantile'``, ignored otherwise.
        method: 'uniform' for equally spaced edges on [0, 1], or 'quantile'
            for edges at equally spaced quantiles of ``confidences``.

    Returns:
        A 1-D tensor of ``num_bins + 1`` edges whose first and last entries are
        forced to 0 and 1.

    Raises:
        ValueError: if ``method`` is unknown, or if ``method='quantile'`` and
            ``confidences`` is None.
    """
    if method == 'uniform':
        n_bins = torch.linspace(0, 1, num_bins + 1)
    elif method == 'quantile':
        # The previous check (`confidences.all() == None`) raised AttributeError
        # for None and was always False for real data, so it never fired.
        if confidences is None:
            raise ValueError("confidence values are needed.")
        quantile_levels = torch.linspace(0, 1, num_bins + 1).numpy()
        n_bins = torch.tensor(np.quantile(confidences, quantile_levels))
    else:
        raise ValueError(f"Unexpected binning method: {method}")

    # Pin the outer edges so every confidence in [0, 1] falls inside a bin.
    n_bins[0], n_bins[-1] = 0., 1.
    return n_bins

def idx_bins(confidence, n_bins):
    """Map each confidence to the zero-based index of the bin it falls into.

    ``n_bins`` holds the bin edges. Values at or beyond the last edge are
    assigned to the last bin.
    """
    last_position = len(n_bins) - 1
    positions = np.digitize(confidence.numpy(), n_bins)
    zero_based = np.minimum(positions, last_position) - 1
    return torch.tensor(zero_based)

def interpolate_nan(a):
    """Fill the NaN entries of a 1-D array by linear interpolation.

    NaNs before the first / after the last valid entry take the nearest valid
    value. The input array is left untouched; the result is returned as a
    float32 tensor.
    Slightly modified from the code in the "minimum-calibration..." NeurIPS2023.
    """
    filled = a.copy()
    missing = np.isnan(filled)
    known = ~missing
    positions = np.arange(len(filled))
    filled[missing] = np.interp(positions[missing], positions[known], filled[known])
    return torch.tensor(filled).float()

def calc_bound_cal(mi, n, b):
    """Evaluate ``sqrt(2 * (mi + b * log 2) / n)``."""
    numerator = 2*(mi + b*np.log(2))
    return np.sqrt(numerator / n)


def _read_pcam_images(path):
    """Read the "x" dataset of an HDF5 file as flattened, [0, 1]-scaled grayscale rows."""
    with h5py.File(path, 'r') as hf:
        images = hf["x"][:]
    gray = np.dot(images[..., :3], [0.299, 0.587, 0.114]) ## gray scale
    return gray.reshape(np.shape(gray)[0], -1) / 255 ## normalization


def _read_pcam_labels(path):
    """Read the "y" dataset of an HDF5 file as a flat label vector."""
    with h5py.File(path, 'r') as hf:
        return hf["y"][:].flatten()


def _load_fcmi_dataset(data):
    """Load features/labels and the sample-size grid for the named dataset.

    Files are read from ``<cwd>/kitti_features/`` or ``<cwd>/pcam/``.

    Returns:
        ``(X_train, y_train, X_test, y_test, n_sample)``.

    Raises:
        ValueError: if ``data`` is neither 'kitti' nor 'pcam'.
    """
    if data == 'kitti':
        n_sample = [200, 1000, 3000, 7000] ## KITTI
        files = ['kitti_all_train.data', 'kitti_all_train.labels', 'kitti_all_test.data', 'kitti_all_test.labels']
        file_path = os.getcwd() + '/kitti_features/'

        X_train = np.loadtxt(os.path.join(file_path, files[0]), np.float64, skiprows=1)
        y_train = np.loadtxt(os.path.join(file_path, files[1]), np.int32, skiprows=1)
        X_test = np.loadtxt(os.path.join(file_path, files[2]), np.float64, skiprows=1)
        y_test = np.loadtxt(os.path.join(file_path, files[3]), np.int32, skiprows=1)

        # Collapse the labels to a binary problem: positive vs. non-positive.
        y_train = np.where(y_train > 0, 1, 0)
        y_test = np.where(y_test > 0, 1, 0)
    elif data == 'pcam':
        n_sample = [500, 3000, 7000, 10000] ## Pcam
        # The "valid" split files are the ones used as the training pool here.
        files = ["camelyonpatch_level_2_split_valid_x.h5", "camelyonpatch_level_2_split_valid_y.h5", "camelyonpatch_level_2_split_test_x.h5", "camelyonpatch_level_2_split_test_y.h5"]
        file_path = os.getcwd() + '/pcam/'

        X_train = _read_pcam_images(os.path.join(file_path, files[0]))
        y_train = _read_pcam_labels(os.path.join(file_path, files[1]))
        X_test = _read_pcam_images(os.path.join(file_path, files[2]))
        y_test = _read_pcam_labels(os.path.join(file_path, files[3]))
    else:
        raise ValueError(f"Unexpected dataset: {data}")

    return X_train, y_train, X_test, y_test, n_sample


def get_fcmi_results_for_fixed_z(n_caldata=1000, n_bins=None, data='kitti', classifier='xgboost'):
    """Measure ECE train/test gaps under recalibration and the matching bounds.

    For every sample size ``n`` of the chosen dataset and every seed in
    ``range(40)``, ``2n`` samples are drawn, a Bernoulli(1/2) mask picks one
    element of each consecutive pair for training, a classifier is fitted, and
    the absolute gap between the ECE on the training half and on the held-out
    half is recorded for two recalibration schemes: one using ``n_caldata``
    extra samples drawn from the test set, and one reusing the training half.
    The runs of each sample size are then summarised and passed to
    ``estimate_fcmi_bound_umb`` to obtain mutual-information estimates that are
    converted to bound values with ``calc_bound_cal``.

    Args:
        n_caldata: number of test-set samples drawn for recalibration
            (converted to ``int``).
        n_bins: number of quantile bins; defaults to the integer cube root of
            the training-pool size.
        data: 'kitti' or 'pcam'.
        classifier: 'xgboost' or 'randomforest'.

    Returns:
        dict with
        ``"ece_gap"``: one ``[mean, std]`` per sample size (calibration-set scheme),
        ``"ece_gap_proposed"``: one ``[mean, std]`` per sample size (training-data scheme),
        ``"bound_values"``: one array of bound values per sample size.

    Raises:
        ValueError: on an unknown ``data`` or ``classifier``.
    """
    # Validate up front so a typo fails before the data is loaded.
    # (The previous in-loop error formatted the not-yet-bound ``model``.)
    if classifier not in ('xgboost', 'randomforest'):
        raise ValueError(f"Unexpected model: {classifier}")

    # np.random.choice needs an integer size; the CLI may hand over a float.
    cal_size = int(n_caldata) ## recalibration

    random_state = 1
    S_seeds = np.arange(0,40)
    n_seeds = len(S_seeds)

    X_train, y_train, X_test, y_test, n_sample = _load_fcmi_dataset(data)

    if n_bins is None:
        n_bins = int(X_train.shape[0] ** (1/3))
    # Always defined, also when the caller supplies ``n_bins`` explicitly.
    n_bins_eval = n_bins

    # Accuracies are tracked for inspection only; they are not returned.
    train_accs = []
    val_accs = []
    preds = []
    masks = []
    labels = []

    ece_gap_umb = []
    ece_gap_umb_proposed = []
    bins_umb = []
    bins_list_umb = []
    gap_res = []
    gap_res_proposed = []
    bounds = []
    for n in tqdm(n_sample):
        for seed in tqdm(S_seeds):
            np.random.seed(seed)
            ## preparing all datasets (train/test)
            all_indices = np.random.choice(X_train.shape[0], size=2*n, replace=False)
            X, y = X_train[all_indices], y_train[all_indices]

            if classifier == 'xgboost':
                model = xgb.XGBClassifier(booster="gbtree", n_estimators=100, random_state=random_state, n_jobs=-1)
            else:
                model = RandomForestClassifier(n_estimators=100, criterion="gini", min_samples_split=2, bootstrap=True, n_jobs=-1, random_state=random_state)

            cur_mask = np.random.randint(2, size=(n,)) ## Ber(1/2)
            train_indices = 2*np.arange(n) + cur_mask
            x_tr, y_tr = X[train_indices], y[train_indices]
            model.fit(X=x_tr, y=y_tr)

            # NOTE(review): isnotnan drops rows, which would misalign preds with
            # ``y`` if any NaN prediction ever occurred -- confirm it cannot.
            cur_preds = isnotnan(model.predict_proba(X))
            preds.append(torch.tensor(cur_preds))
            masks.append(torch.tensor(cur_mask))
            labels.append(torch.tensor(y))

            ## Accuracy
            cur_train_acc = compute_acc(preds=cur_preds, mask=cur_mask, labels=y)
            cur_val_acc = compute_acc(preds=cur_preds, mask=1-cur_mask, labels=y)
            train_accs.append(cur_train_acc)
            val_accs.append(cur_val_acc)

            ## ECE w./ recalibration (UMB)
            conf = torch.tensor(cur_preds[train_indices]).softmax(1).max(1).values
            cur_bins_umb = compute_bins(num_bins=n_bins, confidences=conf, method='quantile') ## for L1-ECE
            bins_list_umb.append(cur_bins_umb)
            cur_bins = idx_bins(conf, cur_bins_umb)
            bins_umb.append(cur_bins.numpy())

            # UMB on recalibration data
            cal_indices = np.random.choice(X_test.shape[0], size=cal_size, replace=False)
            x_cal, y_cal = X_test[cal_indices], y_test[cal_indices]
            cal_preds = isnotnan(model.predict_proba(x_cal))
            cal_preds_t = torch.tensor(cal_preds)
            cal_labels_t = torch.tensor(y_cal).long()

            # Labels must be ``y`` (the sampled 2n labels aligned with
            # ``cur_preds``), not the whole ``y_train`` pool.
            ece_tr = compute_ece(cur_preds, mask=cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True, cal_data=True, preds_cal=cal_preds_t, labels_cal=cal_labels_t)
            ece_tes = compute_ece(cur_preds, mask=1-cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True, cal_data=True, preds_cal=cal_preds_t, labels_cal=cal_labels_t)
            ece_gap_umb.append(np.abs(ece_tr - ece_tes))

            # UMB on full training dataset (proposed recalibration)
            ece_tr = compute_ece(cur_preds, mask=cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True)
            ece_tes = compute_ece(cur_preds, mask=1-cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True)
            ece_gap_umb_proposed.append(np.abs(ece_tr - ece_tes))

    # Runs are stored sample-size-major, so each sample size owns a contiguous
    # slice of ``n_seeds`` entries.
    for c, n in enumerate(tqdm(n_sample)):
        start = c * n_seeds
        stop = start + n_seeds

        gaps_cal = np.array(ece_gap_umb[start:stop])
        gap_res.append([gaps_cal.mean(), gaps_cal.std()])
        gaps_proposed = np.array(ece_gap_umb_proposed[start:stop])
        gap_res_proposed.append([gaps_proposed.mean(), gaps_proposed.std()])

        _, mis_list = estimate_fcmi_bound_umb(masks=masks[start:stop], preds=preds[start:stop], labels=labels[start:stop], bins=bins_umb[start:stop], num_examples=n, num_classes=2, n_bins=n_bins_eval, bins_list=bins_list_umb[start:stop],
                                       norm='l1', loss='reuse', recalibration=True, verbose=False, return_list_of_mis=True)
        # The bound uses the cube root of the sample size as its ``b`` term.
        bound_value = np.array([calc_bound_cal(mi, n, b=int(n ** (1/3))) for mi in mis_list])
        bounds.append(bound_value)

    return {"ece_gap": gap_res,
            "ece_gap_proposed": gap_res_proposed,
            "bound_values": bounds
            }

def main():
    """Run the experiment selected on the command line and pickle its results.

    The results are written to
    ``<results_dir>/<exp_name>/results_<model>.pkl``; the directory is created
    if it does not exist yet.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, required=True)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--model', type=str, default='xgboost', help='models')
    # A sample count must be an integer: a float here later fails inside
    # np.random.choice(size=...).
    parser.add_argument('--n_recalibration', type=int, default=1000, help='number of samples for recalibration')
    parser.set_defaults(parse=True)
    args = parser.parse_args()
    print(args)

    # Create the output directory first so a bad path fails before the long run.
    output_dir = os.path.join(args.results_dir, args.exp_name)
    os.makedirs(output_dir, exist_ok=True)

    results = get_fcmi_results_for_fixed_z(n_caldata=args.n_recalibration, data=args.exp_name, classifier=args.model)
    results_file_path = os.path.join(output_dir, 'results_{}.pkl'.format(args.model))
    with open(results_file_path, 'wb') as f:
        pickle.dump(results, f)

if __name__ == '__main__':
    main()