import math

from ncephes import bdtr, lgam
from numba import jit
import numpy as np
import scipy.stats


@jit(nopython=True)
def inner_sum_fast(n, px, py, pz, sum_end_1, sum_end_2, tolerance=1e-18):
    """Truncated binomial-weighted sum of a product of two binomial tail terms.

    For every ``j`` in a window centred on ``n * pz`` the summand is

        pmf(j; n, pz) * bdtr(min(sum_end_2, j), j, py)
                      * (1 - bdtr(min(sum_end_1, j + 1) - 1, j, px))

    where the pmf is evaluated through ``lgam`` in log space (assuming
    ``lgam``/``bdtr`` are the Cephes log-gamma and binomial CDF).

    The window half-width is ``sqrt(n / 2 * log(2 / tolerance))``; by
    Hoeffding's inequality a Binomial(n, pz) variable falls outside that
    window with probability at most ``tolerance``. ``tolerance`` is added to
    the returned sum, presumably to compensate for the skipped terms.

    Arguments:
        n:          number of trials of the outer binomial
        px, py:     success probabilities of the two inner binomials
        pz:         success probability of the outer binomial
        sum_end_1:  lower cut-off used for the upper-tail factor
        sum_end_2:  upper cut-off used for the CDF factor
        tolerance:  truncation budget, also added to the result

    Return: the truncated sum plus ``tolerance``.
    """
    def xlogy(x, y):
        # Treat 0 * log(y) as 0 so that log(0) never turns the term into NaN.
        if x == 0.0:
            return 0
        return x * np.log(y)

    half_width = np.sqrt((n / 2) * np.log(2.0 / tolerance))
    first = int(max(np.floor(n * pz - half_width), 0))
    last = int(min(np.ceil(n * pz + half_width), n))

    total = 0.0
    for trials in range(first, last + 1):
        # Binomial(n, pz) mass at `trials`, assembled in log space.
        log_pmf = (lgam(n + 1) - lgam(trials + 1) - lgam(n - trials + 1)
                   + xlogy(trials, pz) + xlogy(n - trials, 1 - pz))
        weight = np.exp(log_pmf)

        # Cut-offs are clamped because the inner binomials have only
        # `trials` trials.
        cdf_cap = min(sum_end_2, trials)
        tail_start = min(sum_end_1, trials + 1)

        at_most = bdtr(cdf_cap, trials, py)
        at_least = 1.0 - bdtr(tail_start - 1, trials, px)

        total += weight * at_most * at_least
    return total + tolerance




def inner_sum_conservative(n, px, py, pz, sum_end_1, sum_end_2):
    """Untruncated SciPy evaluation of the inner sum.

    Accumulates, for every ``j`` in ``0..n``,

        pmf(j; n, pz) * cdf(sum_end_2; j, py) * (1 - cdf(sum_end_1 - 1; j, px))

    using ``scipy.stats.binom`` for all three factors.

    Arguments:
        n:          number of trials of the outer binomial
        px, py:     success probabilities of the two inner binomials
        pz:         success probability of the outer binomial
        sum_end_1:  lower cut-off used for the upper-tail factor
        sum_end_2:  upper cut-off used for the CDF factor

    Return: the full sum over ``j``.
    """
    binom = scipy.stats.binom
    total = 0.0
    for trials in range(n + 1):
        weight = binom.pmf(trials, n, pz)
        at_most = binom.cdf(sum_end_2, trials, py)
        at_least = 1.0 - binom.cdf(sum_end_1 - 1, trials, px)
        total += weight * at_most * at_least
    return total


def inner_sum(p1, p2, eta, n, alpha1, alpha2, computation_mode='fast', verbose=False):
    """Validate the two-model setup and evaluate the inner sum.

    Builds the joint error table of two models from their marginals and
    agreement probability, derives the three binomial parameters used by the
    inner-sum routines, and dispatches to the selected implementation.

    Arguments:
        p1:                error probability of the first model
        p2:                error probability of the second model
        eta:               agreement probability
        n:                 number of data points in the test set
        alpha1:            deviation from the mean error probability p1
        alpha2:            deviation from the mean error probability p2
        computation_mode:  'fast' (truncated, jitted) or 'conservative' (SciPy)
        verbose:           print the joint probabilities when true

    Return: the value produced by the selected inner-sum routine.

    Raises:
        AssertionError: if an argument is outside its valid range.
        ValueError: if the joint table has a non-positive entry, if
            p11 < p1 * p2 (factorization assumption), or if
            ``computation_mode`` is unknown.
    """
    assert n >= 1
    assert 0 <= p1 <= 1.0
    assert 0 <= p2 <= 1.0
    assert 0 <= eta <= 1.0
    assert -p1 <= alpha1 <= 1.0 - p1
    assert -p2 <= alpha2 <= 1.0 - p2

    # Joint table: first index is the error indicator of model 1, second of
    # model 2.
    p11 = (p1 + p2 + eta - 1.0) / 2.0
    p01 = (1.0 + p2 - p1 - eta) / 2.0
    p10 = (1.0 + p1 - p2 - eta) / 2.0
    p00 = (1.0 + eta - p1 - p2) / 2.0

    if any(entry <= 0.0 for entry in (p11, p01, p10, p00)):
        raise ValueError('Impossible combination of p1, p2, and eta')

    if verbose:
        for label, value in (('p11', p11), ('p01', p01), ('p10', p10), ('p00', p00)):
            print(f'{label} {value}')

    # (p10 + p11) and (p01 + p11) are the two marginals, so this is p1 * p2.
    marginal_product = (p10 + p11) * (p01 + p11)
    if p11 < marginal_product:
        gap = marginal_product - p11
        raise ValueError(f'Computation factorization assumption violated by {gap}')

    px = p11 / (p01 + p11)
    py = p11 / (p10 + p11)
    pz = marginal_product / p11

    sum_end_2 = math.floor(n * (p2 + alpha2))
    sum_end_1 = math.ceil(n * (p1 + alpha1))

    if computation_mode == 'fast':
        return inner_sum_fast(n, px, py, pz, sum_end_1, sum_end_2)
    if computation_mode == 'conservative':
        return inner_sum_conservative(n, px, py, pz, sum_end_1, sum_end_2)
    raise ValueError(f'Unknown computation mode {computation_mode}')





@jit(nopython=True)
def inner_sum_naive_bayes(n, px, pw, sum_end_1, sum_end_2, k1=1, k2=1, tolerance=1e-18):
    """Truncated inner sum with the two tail factors raised to powers.

    For every ``j`` in a window centred on ``n * pw`` the summand is

        pmf(j; n, pw) * bdtr(min(sum_end_2, j), j, px) ** k2
                      * (1 - bdtr(min(sum_end_1, j + 1) - 1, j, px)) ** k1

    with the pmf evaluated through ``lgam`` in log space (assuming
    ``lgam``/``bdtr`` are the Cephes log-gamma and binomial CDF). Both
    factors use the same inner success probability ``px``.

    The window half-width is ``sqrt(n / 2 * log(2 / tolerance))``; by
    Hoeffding's inequality a Binomial(n, pw) variable falls outside that
    window with probability at most ``tolerance``. ``tolerance`` is added to
    the returned sum, presumably to compensate for the skipped terms.

    Arguments:
        n:          number of trials of the outer binomial
        px:         success probability of the inner binomials
        pw:         success probability of the outer binomial
        sum_end_1:  lower cut-off used for the upper-tail factor
        sum_end_2:  upper cut-off used for the CDF factor
        k1:         exponent applied to the upper-tail factor
        k2:         exponent applied to the CDF factor
        tolerance:  truncation budget, also added to the result

    Return: the truncated sum plus ``tolerance``.
    """
    def xlogy(x, y):
        # Treat 0 * log(y) as 0 so that log(0) never turns the term into NaN.
        if x == 0.0:
            return 0
        return x * np.log(y)

    half_width = np.sqrt((n / 2) * np.log(2.0 / tolerance))
    first = int(max(np.floor(n * pw - half_width), 0))
    last = int(min(np.ceil(n * pw + half_width), n))

    total = 0.0
    for trials in range(first, last + 1):
        # Binomial(n, pw) mass at `trials`, assembled in log space.
        log_pmf = (lgam(n + 1) - lgam(trials + 1) - lgam(n - trials + 1)
                   + xlogy(trials, pw) + xlogy(n - trials, 1 - pw))
        weight = np.exp(log_pmf)

        # Cut-offs are clamped because the inner binomials have only
        # `trials` trials.
        cdf_cap = min(sum_end_2, trials)
        tail_start = min(sum_end_1, trials + 1)

        at_most = np.power(bdtr(cdf_cap, trials, px), k2)
        at_least = np.power((1.0 - bdtr(tail_start - 1, trials, px)), k1)

        total += weight * at_most * at_least

    return total + tolerance



@jit(nopython=True)
def inner_sum_exact_naive_bayes(n, px, pw, k, sum_end_left, sum_end_right):
    """Untruncated sum of pmf(jj; n, pw) * (1 - pE(jj) ** k) over jj = 0..n.

    ``pE(jj)`` is ``bdtr(min(sum_end_right, jj), jj, px) - bdtr(sum_end_left, jj, px)``
    when ``jj >= sum_end_left`` and 0 otherwise. The pmf is evaluated in log
    space through ``lgam`` (assuming ``lgam``/``bdtr`` are the Cephes
    log-gamma and binomial CDF).

    Arguments:
        n:              number of trials of the outer binomial
        px:             success probability of the inner binomial
        pw:             success probability of the outer binomial
        k:              exponent applied to pE
        sum_end_left:   lower cut-off of the inner interval
        sum_end_right:  upper cut-off of the inner interval

    Return: the accumulated sum.

    NOTE(review): `inner_sum_exact_naive_bayes_slow` evaluates its CDFs at
    ``n * (error_rate +/- eps) - 1``, while `exact_naive_bayes_prob_union`
    passes ``floor(n * (error_rate +/- eps))`` to this function, so the two
    variants use cut-offs that differ by one index -- confirm which is
    intended.
    """
    result = 0.0
    
    def xlogy(x,y):
        # Treat 0 * log(y) as 0 so that log(0) never turns the term into NaN.
        if x == 0.0:
            return 0
        return x * np.log(y)
    
    for jj in range(n + 1):
        # Binomial(n, pw) mass at jj, assembled in log space.
        cur_summand = np.exp(lgam(n + 1) - lgam(jj + 1) - lgam(n - jj + 1) + xlogy(jj, pw) + xlogy(n - jj, 1 - pw))

        if jj < sum_end_left:
            # With fewer than sum_end_left trials the inner count cannot
            # exceed sum_end_left, so the interval probability is zero.
            pE = 0.0
        else:
            # Clamp the upper cut-off to the number of inner trials.
            aux_sum_end_right = min(sum_end_right, jj)
            pE = bdtr(aux_sum_end_right, jj, px) - bdtr(sum_end_left, jj, px)
            
        cur_summand *= (1.0 - np.power(pE, k))
        result += cur_summand
    return result


def inner_sum_exact_naive_bayes_slow(n, px, pw, k, eps, error_rate):
    """SciPy evaluation of the exact naive-Bayes inner sum.

    For every ``j`` in ``0..n`` accumulates

        pmf(j; n, pw) * (1 - (cdf(hi; j, px) - cdf(lo; j, px)) ** k)

    with ``hi = n * (error_rate + eps) - 1`` and
    ``lo = n * (error_rate - eps) - 1``.

    Arguments:
        n:           number of trials of the outer binomial
        px:          success probability of the inner binomial
        pw:          success probability of the outer binomial
        k:           exponent applied to the interval probability
        eps:         deviation from the mean
        error_rate:  error probability of the models

    Return: the accumulated sum.
    """
    binom = scipy.stats.binom
    upper_cut = n * (error_rate + eps) - 1
    lower_cut = n * (error_rate - eps) - 1

    total = 0.0
    for trials in range(n + 1):
        weight = binom.pmf(trials, n, pw)
        in_band = binom.cdf(upper_cut, trials, px) - binom.cdf(lower_cut, trials, px)
        total += weight * (1.0 - math.pow(in_band, k))
    return total


@jit(nopython=True)
def compute_num_models(rho, num_evaluations):
    """Return sum of floor(rho ** i) for i = 0 .. num_evaluations - 1 as a float."""
    total = 0.0
    for exponent in range(num_evaluations):
        block_size = math.floor(math.pow(rho, exponent))
        total += block_size
    return total


@jit(nopython=True)
def get_rho(num_models, num_evaluations, target_difference=3, verbose=True):
    """Bisect for a growth factor rho used to size geometric model blocks.

    The search interval starts at ``[1, num_models ** (1 / (2 * num_evaluations))]``
    and is halved until ``compute_num_models`` at the two ends differs by
    less than ``target_difference``. The midpoint replaces the upper end when
    its model count reaches ``num_models`` and the lower end otherwise.

    Arguments:
        num_models:         target total number of models
        num_evaluations:    number of terms summed by ``compute_num_models``
        target_difference:  stop once the end-point counts are closer than this
        verbose:            print the bisection state on every step

    Return: the upper end of the final interval.
    """
    lo = 1.0
    lo_count = compute_num_models(lo, num_evaluations)
    hi = np.exp(np.log(num_models) / (2 * num_evaluations))
    hi_count = compute_num_models(hi, num_evaluations)

    while hi_count - lo_count >= target_difference:
        mid = (lo + hi) / 2.0
        mid_count = compute_num_models(mid, num_evaluations)
        if verbose:
            print(lo, hi, lo_count, hi_count, mid, mid_count)
        if mid_count >= num_models:
            hi = mid
            hi_count = mid_count
        else:
            lo = mid
            lo_count = mid_count
    return hi


def naive_bayes_prob_union(n, k, error_rate, eta, eps, num_evaluations = 1, rho = None):
    """Accumulate a union-style bound over k models in geometric blocks.

    Starts from the two-sided binomial tail of a single model and then walks
    through the model indices ``1..k`` in blocks of size ``floor(rho ** j)``
    (j = 0, 1, ...). Each block adds its size (clipped so the total never
    exceeds k) times the left and right ``inner_sum_naive_bayes`` terms
    evaluated with ``k1`` set to the first index of the block.

    Arguments:
        n:                number of data points in the test set
        k:                number of models
        error_rate:       error probability of the models
        eta:              agreement probability
        eps:              deviation from mean
        num_evaluations:  forwarded to ``get_rho`` when ``rho`` is None
        rho:              block growth factor; computed by ``get_rho`` if None

    Return: the accumulated probability bound.

    Raises:
        AssertionError: if n, error_rate or eta are out of range, or if the
            derived outer probability pw exceeds 1.
    """
    assert n >= 1
    assert 0 <= error_rate <= 1.0
    assert 0 <= eta <= 1.0

    # Joint probabilities for two models sharing the same error rate.
    p11 = (2 * error_rate + eta - 1.0) / 2.0
    p01 = (1.0 - eta) / 2.0

    marginal = p01 + p11
    px = p11 / marginal
    pw = marginal * marginal / p11

    assert pw <= 1.0

    if rho is None:
        rho = get_rho(k, num_evaluations, verbose = False)

    lower_cut = n * (error_rate - eps)
    upper_cut = n * (error_rate + eps)

    # Two-sided binomial tail of a single model.
    total = 0.0
    total += scipy.stats.binom.cdf(np.floor(lower_cut), n, error_rate)
    total += 1. - scipy.stats.binom.cdf(np.ceil(upper_cut), n, error_rate)

    left_ends = (math.ceil(lower_cut), math.floor(lower_cut))
    right_ends = (math.ceil(upper_cut), math.floor(upper_cut))

    current_idx = 1
    step = 0
    while current_idx <= k:
        block = np.floor(np.power(rho, step))

        # The last block is clipped so that exactly k models are counted.
        if current_idx + block > k:
            multiplicity = k - current_idx + 1
        else:
            multiplicity = block

        total += multiplicity * inner_sum_naive_bayes(
            n, px, pw, left_ends[0], left_ends[1], k1 = current_idx)
        total += multiplicity * inner_sum_naive_bayes(
            n, px, pw, right_ends[0], right_ends[1], k1 = current_idx)

        current_idx = block + current_idx
        step += 1

    return total


def exact_naive_bayes_prob_union(n, k, error_rate, eta, eps, inner_part='fast'):
    """Evaluate the exact naive-Bayes union probability for k models.

    Derives the inner and outer binomial parameters from ``error_rate`` and
    ``eta`` and delegates to the jitted routine (``inner_part='fast'``) or to
    the SciPy routine (``inner_part='slow'``).

    Arguments:
        n:           number of data points in the test set
        k:           number of models
        error_rate:  error probability of the models
        eta:         agreement probability
        eps:         deviation from mean
        inner_part:  'fast' or 'slow'

    Return: the value computed by the selected inner routine.

    Raises:
        AssertionError: if n, error_rate or eta are out of range, if the
            derived outer probability pw exceeds 1, or if ``inner_part`` is
            neither 'fast' nor 'slow'.
    """
    assert n >= 1
    assert 0 <= error_rate <= 1.0
    assert 0 <= eta <= 1.0

    # Joint probabilities for two models sharing the same error rate.
    p11 = (2 * error_rate + eta - 1.0) / 2.0
    p01 = (1.0 - eta) / 2.0

    marginal = p01 + p11
    px = p11 / marginal
    pw = marginal * marginal / p11

    assert pw <= 1.0

    if inner_part != 'fast':
        assert inner_part == 'slow'
        return inner_sum_exact_naive_bayes_slow(n, px, pw, k, eps, error_rate)

    sum_end_left = math.floor(n * (error_rate - eps))
    sum_end_right = math.floor(n * (error_rate + eps))
    return inner_sum_exact_naive_bayes(n, px, pw, k, sum_end_left, sum_end_right)



def testable_models_binomial(n, epsilon, accuracy, delta):
    """Maximum number of models tested at confidence delta using binomial CDF.

    Arguments:
        n:          sample size
        epsilon:    confidence interval width
        accuracy:   accuracy of all of the models
        delta:      confidence bound (w.p. 1-delta)

    Return: Total number of testable models.
    """
    lower_tail = scipy.stats.binom.cdf((accuracy - epsilon) * n, n, accuracy)
    upper_tail = (1. - scipy.stats.binom.cdf((accuracy + epsilon) * n, n, accuracy))
    binomial_mistake_prob = lower_tail + upper_tail
    return delta / binomial_mistake_prob


def testable_models_similarity(n, epsilon, accuracy, delta, eta):
    """Maximum number of models tested at confidence delta using similarity.

    For a candidate model count K and an offset t in [0, epsilon], the
    mistake probability is bounded by ``term1(t) + (K - 1) * term2(t)``:
    ``term1`` is a two-sided binomial tail at width ``epsilon - t`` and
    ``term2`` is the sum of two ``inner_sum`` evaluations. The bound for K is
    the minimum over 50 evenly spaced values of t, and a bisection over
    ``[1, 1e16]`` finds the largest K whose bound does not exceed ``delta``.

    Arguments:
        n:          sample size
        epsilon:    confidence interval width
        accuracy:   accuracy of all of the models
        delta:      confidence bound (w.p. 1-delta)
        eta:        Pairwise similarity for all models.

    Return: Total number of testable models (0. if even the first bisection
    midpoints all exceed ``delta``).

    Raises: whatever ``inner_sum`` raises for an invalid combination of
    ``accuracy``, ``eta`` and ``epsilon``.
    """
    error_rate = 1. - accuracy
    ts = np.linspace(0.0, epsilon, 50)

    def mistake_terms(t):
        # Single-model two-sided tail at the narrowed width epsilon - t.
        lower_tail = scipy.stats.binom.cdf((accuracy - epsilon + t) * n, n, accuracy)
        upper_tail = (1. - scipy.stats.binom.cdf((accuracy + epsilon - t) * n, n, accuracy))
        term1 = lower_tail + upper_tail

        # Pairwise term, multiplied by (K - 1) in the bound.
        term2 = inner_sum(error_rate, error_rate, eta, n, -epsilon + t, -epsilon)
        term2 += inner_sum(error_rate, error_rate, eta, n, epsilon, epsilon - t)

        return term1, term2

    print(epsilon)

    # Neither term depends on K, so evaluate them once per offset instead of
    # once per offset on every bisection step.
    terms = [mistake_terms(t) for t in ts]

    # Binary search to get the best K
    left = 1
    right = 1e16
    best_k = 0.

    while right - left >= 1:
        k = (left + right) / 2
        prob_error = np.min([term1 + (k - 1) * term2 for term1, term2 in terms])
        if prob_error > delta:
            right = k
        else:
            best_k = k
            left = k

    return best_k


def testable_models_naive_bayes(n, epsilon, accuracy, delta, eta):
    """Maximum number of models tested at confidence delta using naive Bayes.

    Bisects over ``[1, 1e26]`` for the largest model count K whose
    ``exact_naive_bayes_prob_union`` value (evaluated at error rate
    ``1 - accuracy``) does not exceed ``delta``. Prints ``epsilon`` once
    before the search.

    Arguments:
        n:          sample size
        epsilon:    confidence interval width
        accuracy:   accuracy of all of the models
        delta:      confidence bound (w.p. 1-delta)
        eta:        Pairwise similarity for all models.

    Return: Total number of testable models (0. if no midpoint tried by the
    bisection satisfies the bound).
    """
    error_rate = 1 - accuracy

    lo = 1
    hi = 1e26
    best_k = 0.

    print(epsilon)

    # Stop once the bracket is narrower than one model.
    while hi - lo >= 1:
        mid = (lo + hi) / 2
        if exact_naive_bayes_prob_union(n, mid, error_rate, eta, epsilon) > delta:
            hi = mid
        else:
            best_k = mid
            lo = mid

    return best_k

