
import math


def max_entropy(n, k):
    """The maximum entropy we could get with n units and k winners.

    Each unit is treated as a binary variable that is active with probability
    ``s = k / n``; the result is ``n * H(s)`` where ``H`` is the binary
    entropy function measured in bits.

    :param n: number of units. ``n == 0`` returns ``0.0`` (no units carry no
        entropy) instead of raising ``ZeroDivisionError``.
    :param k: number of winners (active units).
    :return: (float) total entropy in bits; ``0.0`` whenever ``s`` is not in
        the open interval (0, 1), e.g. ``k == 0`` or ``k == n``.
    """
    if n == 0:
        # Guard the division below: with no units there is nothing to encode.
        return 0.0
    s = float(k) / n
    if not 0.0 < s < 1.0:
        # Degenerate activity level: log(0) is undefined, and the limit of
        # the binary entropy at s -> 0 or s -> 1 is 0.
        return 0.0
    entropy = -s * math.log2(s) - (1 - s) * math.log2(1 - s)
    return n * entropy


def binary_entropy(x):
    """Compute the element-wise binary entropy (in bits) of Bernoulli variables.

    :param x: (torch tensor) probability that each variable equals 1.
    :return: a pair ``(entropy, total)`` where ``entropy`` is a tensor shaped
        like ``x`` holding ``H(p) = -p*log2(p) - (1-p)*log2(1-p)`` per element,
        and ``total`` is its sum.
    """
    complement = 1 - x
    per_variable = -x * x.log2() - complement * complement.log2()
    # At p == 0 or p == 1 one term is 0 * log2(0) = NaN; the entropy limit
    # there is 0, so overwrite those entries in place.
    degenerate = (x * complement) == 0
    per_variable.masked_fill_(degenerate, 0)
    total = per_variable.sum()
    return per_variable, total
