import torch
from torch import sigmoid
from torch.nn.functional import binary_cross_entropy_with_logits as BCEWL
from torch.nn.functional import binary_cross_entropy as BCE


def err(logits_x, logits_y):
    """Element-wise disagreement indicator between two sets of logits.

    A logit counts as a positive prediction when it is strictly greater than
    zero. The result is a float32 tensor holding 1.0 where the two predictions
    differ and 0.0 where they agree. It is NOT reduced; take ``.mean()`` for an
    error rate.
    """
    predicted_x = logits_x > 0
    predicted_y = logits_y > 0
    disagreement = torch.logical_xor(predicted_x, predicted_y)
    return disagreement.float()


class KDLoss:
    """Binary (two-class) knowledge-distillation loss from a teacher's logits.

    Both the student ``input`` and the teacher ``target`` are logits. The loss is

        rho * BCE(sigmoid(input / T), sigmoid(target / T))
        + (1 - rho) * BCE(sigmoid(input), 1[target > 0])

    Each term uses ``binary_cross_entropy_with_logits``' default mean
    reduction. The default ``rho=0.0`` keeps only the hard-label term.
    """

    def __init__(self, rho=0.0, T=1.0):
        # rho: weight of the soft (teacher) term; T: distillation temperature.
        self.rho, self.T = rho, T

    def __call__(self, input, target):
        temperature = self.T
        teacher_probs = sigmoid(target / temperature)
        # The teacher's hard decision, cast to target's dtype for BCE.
        teacher_labels = (target > 0).to(target.dtype)
        soft_term = BCEWL(input / temperature, teacher_probs)
        hard_term = BCEWL(input, teacher_labels)
        return self.rho * soft_term + (1 - self.rho) * hard_term


class KDLoss_TS:
    """Binary KD loss with separately supplied soft logits and hard labels.

    ``input`` and ``soft_logits`` are logits (student and teacher). ``hard_labels``
    is passed directly as the BCE target, so it is presumably 0/1 — verify
    against callers. The loss is

        rho * BCE(sigmoid(input / T), sigmoid(soft_logits / T))
        + (1 - rho) * BCE(sigmoid(input), hard_labels)

    Each term uses the default mean reduction. The default ``rho=0.0`` keeps
    only the hard-label term.
    """

    def __init__(self, rho=0.0, T=1.0):
        # rho: weight of the soft (teacher) term; T: distillation temperature.
        self.rho, self.T = rho, T

    def __call__(self, input, soft_logits, hard_labels):
        scaled_student = input / self.T
        teacher_probs = sigmoid(soft_logits / self.T)
        distill_term = BCEWL(scaled_student, teacher_probs)
        supervised_term = BCEWL(input, hard_labels)
        return self.rho * distill_term + (1 - self.rho) * supervised_term


class KDLoss_min_T1:
    """Per-sample minimum of the rho-mixed KD loss at temperature T = 1.

    For teacher logits ``target`` and labels ``hard_labels``, consider the mixed
    loss over a predicted probability ``q``:

        rho * BCE(q, sigmoid(target)) + (1 - rho) * BCE(q, hard_labels)

    Its minimiser is the mixed probability
    ``sigma = rho * sigmoid(target) + (1 - rho) * hard_labels``. This class
    returns the loss evaluated at ``sigma``, element-wise (no reduction).
    """

    def __init__(self, rho=0.0):
        # rho = 0.0 (default) means the pure hard-label loss.
        self.rho = rho

    def __call__(self, target, hard_labels):
        teacher_probs = sigmoid(target)
        sigma = self.rho * teacher_probs + (1 - self.rho) * hard_labels

        # reduction='none' replaces the deprecated reduce=False; it has the
        # same semantics but no deprecation warning on every call.
        soft = BCE(sigma, teacher_probs, reduction='none')
        hard = BCE(sigma, hard_labels, reduction='none')
        return self.rho * soft + (1 - self.rho) * hard


class KDLoss_min:
    """Per-sample minimum of the rho-mixed KD loss for given teacher logits.

    For each teacher logit in ``target``, ``f_eff`` supplies the effective
    student logit that minimises

        rho * BCE(sigmoid(x_e / T), sigmoid(target / T))
        + (1 - rho) * BCE(sigmoid(x_e), 1[target > 0])

    The loss is then evaluated at that logit without reduction.
    """

    def __init__(self, rho=0.0, T=1.0):
        # rho = 0.0 (default) means the pure hard-label loss.
        self.rho = rho
        self.T = T

    def __call__(self, target):
        if self.rho == 0.0:
            # The pure hard-label BCE has infimum 0 (approached as the logit
            # goes to +/-inf), so short-circuit. Allocate on target's device and
            # dtype so the result mixes with tensors from the rho > 0 path.
            # NOTE(review): this returns shape (target.shape[0],) while the
            # other path returns target.shape; they agree only for 1-D targets.
            return torch.zeros(target.shape[0], dtype=target.dtype, device=target.device)

        x_eff = f_eff(target, self.rho, self.T)
        # reduction='none' replaces the deprecated reduce=False (same semantics).
        soft = BCEWL(x_eff / self.T, sigmoid(target / self.T), reduction='none')
        hard = BCEWL(x_eff, (target > 0).to(target.dtype), reduction='none')
        return self.rho * soft + (1 - self.rho) * hard


class KDLoss_min_TS:
    """Per-sample minimum of the rho-mixed KD loss with separate hard labels.

    ``f_eff_TS`` supplies the effective logit minimising

        rho * BCE(sigmoid(x_e / T), sigmoid(soft_logits / T))
        + (1 - rho) * BCE(sigmoid(x_e), hard_labels)

    The loss is then evaluated there without reduction. ``hard_labels`` is used
    as a BCE target, so it is presumably 0/1 — verify against callers.
    """

    def __init__(self, rho=0.0, T=1.0):
        # rho = 0.0 (default) means the pure hard-label loss.
        self.rho = rho
        self.T = T

    def __call__(self, soft_logits, hard_labels):
        if self.rho == 0.0:
            # The pure hard-label BCE has infimum 0 for 0/1 labels. Allocate on
            # the labels' device and dtype so the result mixes with tensors
            # from the other path.
            # NOTE(review): shape is (hard_labels.shape[0],); it matches the
            # other path only for 1-D inputs.
            return torch.zeros(
                hard_labels.shape[0], dtype=hard_labels.dtype, device=hard_labels.device
            )

        x_eff = f_eff_TS(soft_logits, hard_labels, self.rho, self.T)
        # reduction='none' replaces the deprecated reduce=False (same semantics).
        soft = BCEWL(x_eff / self.T, sigmoid(soft_logits / self.T), reduction='none')
        hard = BCEWL(x_eff, hard_labels, reduction='none')
        return self.rho * soft + (1 - self.rho) * hard


def inverse_sigmoid(x):
    """Logit function: map probabilities in (0, 1) back to log-odds.

    Computed as ``log(x / (1 - x))``. The result is +inf at 1, -inf at 0, and
    nan outside [0, 1].
    """
    odds = x / (1 - x)
    return torch.log(odds)


def sigmoid_derivative(x):
    """Derivative of the logistic sigmoid, sigma'(x) = sigma(x) * (1 - sigma(x)).

    Computed as ``sigmoid(x) * sigmoid(-x)``, since 1 - sigma(x) = sigma(-x).
    The old form ``exp(-x) / (1 + exp(-x))**2`` overflows ``exp(-x)`` to inf for
    large negative x and returns ``0 * inf = nan``. This form stays finite for
    every input and keeps precision in both tails.
    """
    return torch.sigmoid(x) * torch.sigmoid(-x)


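# Helpers for solving for the effective logits x_e. d1 / d1_TS return the negative
# derivative of the rho-mixed KD loss w.r.t. x_e (zero at the optimum); d2 returns
# the derivative of d1 w.r.t. x_e, used as the Newton denominator below.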
def d1(x_e, x, rho, T):
    """Negative gradient of the rho-mixed KD loss w.r.t. the effective logit.

    The soft target is ``sigmoid(x / T)`` and the hard target is ``1[x > 0]``:

        rho * (sigmoid(x / T) - sigmoid(x_e / T)) / T
        + (1 - rho) * (1[x > 0] - sigmoid(x_e))

    It is evaluated without autograd tracking.
    """
    with torch.no_grad():
        hard_target = (x > 0).to(x.dtype)
        soft_gap = torch.sigmoid(x / T) - torch.sigmoid(x_e / T)
        hard_gap = hard_target - torch.sigmoid(x_e)
        return rho * soft_gap / T + (1 - rho) * hard_gap


def d1_TS(x_e, soft_logits, hard_labels, rho, T):
    """Negative gradient of the rho-mixed KD loss w.r.t. x_e, with given labels.

    This is ``d1`` with the hard target supplied explicitly:

        rho * (sigmoid(soft_logits / T) - sigmoid(x_e / T)) / T
        + (1 - rho) * (hard_labels - sigmoid(x_e))

    It is evaluated without autograd tracking.
    """
    with torch.no_grad():
        teacher_gap = torch.sigmoid(soft_logits / T) - torch.sigmoid(x_e / T)
        label_gap = hard_labels - torch.sigmoid(x_e)
        return rho * teacher_gap / T + (1 - rho) * label_gap


def d2(x_e, rho, T):
    """Derivative of ``d1`` / ``d1_TS`` w.r.t. x_e (the Newton denominator).

    Equals ``-rho * s_T * (1 - s_T) / T**2 - (1 - rho) * s * (1 - s)`` with
    ``s_T = sigmoid(x_e / T)`` and ``s = sigmoid(x_e)``. It is non-positive for
    0 <= rho <= 1 and underflows to 0 once both sigmoids saturate. It is
    evaluated without autograd tracking.
    """
    with torch.no_grad():
        tempered = torch.sigmoid(x_e / T)
        plain = torch.sigmoid(x_e)
        soft_curv = rho * tempered * (1 - tempered) / (T * T)
        hard_curv = (1 - rho) * plain * (1 - plain)
        return -soft_curv - hard_curv


def f_eff(x, rho, T, iter=30):
    """Effective logit minimising the rho-mixed KD loss for teacher logits ``x``.

    Solves ``d1(x_e, x, rho, T) == 0`` element-wise with ``iter`` Newton steps,
    starting from ``x_e = x``. The soft target is ``sigmoid(x / T)`` and the hard
    target is ``1[x > 0]``. (``iter`` shadows the builtin; the name is kept for
    callers passing it by keyword.)

    Special cases:
      * ``rho == 0``: only the hard term remains, so the optimum is +inf where
        ``x > 0`` and -inf elsewhere.
      * ``T == 1`` and ``rho == 1``: the optimum is ``x`` itself, returned as-is.
    """
    if rho == 0.0:
        # torch.sign(x) * inf would give nan at x == 0. There the hard label
        # (x > 0) is 0, so the optimum is -inf.
        inf = torch.full_like(x, torch.inf)
        return torch.where(x > 0, inf, -inf)

    if T == 1.0 and rho == 1.0:
        return x

    # Newton's method on d1, with d2 = d(d1)/dx_e.
    with torch.no_grad():
        # detach().clone() is what torch.tensor(x) did, minus the copy-construct
        # warning.
        x_e = x.detach().clone()
        for _ in range(iter):
            num = d1(x_e, x, rho, T)
            den = d2(x_e, rho, T)
            # Once both sigmoids saturate (|logit| of roughly 17+ in float32),
            # d2 underflows to 0 and num / den becomes nan or inf. Freeze those
            # entries instead of poisoning the result.
            step = torch.where(den != 0, num / den, torch.zeros_like(num))
            x_e = x_e - step

    return x_e


def f_eff_TS(soft_logits, hard_labels, rho, T, iter=30):
    """Effective logit minimising the rho-mixed KD loss with separate hard labels.

    This is the variant where the soft targets come from a teacher and may
    disagree with ``hard_labels``. It solves
    ``d1_TS(x_e, soft_logits, hard_labels, rho, T) == 0`` element-wise with
    ``iter`` Newton steps, starting from ``x_e = soft_logits``. (``iter``
    shadows the builtin; the name is kept for keyword callers.)

    Special cases:
      * ``rho == 0``: the optimum is +inf where ``hard_labels > 0`` and -inf
        elsewhere. Labels are presumably 0/1 — verify against callers.
      * ``T == 1`` and ``rho == 1``: the optimum is ``soft_logits``, returned as-is.
    """
    if rho == 0.0:
        # torch.sign(hard_labels) * inf turned every 0 label into nan. Map
        # label <= 0 to -inf and label > 0 to +inf instead; this keeps the old
        # result for all nonzero labels.
        inf = torch.full_like(hard_labels, torch.inf)
        return torch.where(hard_labels > 0, inf, -inf)

    if T == 1.0 and rho == 1.0:
        return soft_logits

    # Newton's method on d1_TS, with d2 = d(d1_TS)/dx_e.
    with torch.no_grad():
        # Equivalent to torch.tensor(soft_logits), without the copy-construct
        # warning.
        x_e = soft_logits.detach().clone()
        for _ in range(iter):
            num = d1_TS(x_e, soft_logits, hard_labels, rho, T)
            den = d2(x_e, rho, T)
            # Saturated sigmoids make d2 underflow to 0; freeze those entries
            # rather than produce nan or inf.
            step = torch.where(den != 0, num / den, torch.zeros_like(num))
            x_e = x_e - step

    return x_e


def f_eff_GD(soft_logits, hard_labels, rho, T, iter=100, lr=10.0):
    """Effective logit via gradient descent instead of Newton's method.

    Runs ``iter`` steps of ``x_e += lr * d1_TS(...)``, starting from
    ``soft_logits``; ``d1_TS`` is the negative gradient, so this is descent. The
    learning rate is multiplied by 0.95 roughly 20 times over the run.
    ``iter`` shadows the builtin; the name is kept for keyword callers.

    Special cases match ``f_eff_TS``:
      * ``rho == 0``: +inf where ``hard_labels > 0``, -inf elsewhere.
      * ``T == 1`` and ``rho == 1``: returns ``soft_logits`` as-is.
    """
    if rho == 0.0:
        # torch.sign(hard_labels) * inf gave nan for 0 labels.
        inf = torch.full_like(hard_labels, torch.inf)
        return torch.where(hard_labels > 0, inf, -inf)

    if T == 1.0 and rho == 1.0:
        return soft_logits

    # Decay the lr about 20 times over the run. max(..., 1) guards against
    # interv == 0 (ZeroDivisionError in i % interv) when iter < 20.
    interv = max(iter // 20, 1)
    with torch.no_grad():
        x_e = soft_logits.clone().detach()
        for i in range(iter):
            x_e = x_e + lr * d1_TS(x_e, soft_logits, hard_labels, rho, T)
            if i % interv == 0:
                lr = lr * 0.95

    return x_e
