import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class UpperBound(nn.Module):
    """Loss built from the discriminator's acceptance ratio.

    With ``r = discriminator.acceptance_ratio(gen_samples, true_samples)``,
    the loss is the sum over every element of ``log(r) + r``.
    """

    def __init__(self, discriminator):
        super().__init__()
        # Any object exposing ``acceptance_ratio(gen, true)``.
        self.discriminator = discriminator

    def forward(self, true_samples, gen_samples, b=1e-10):
        """Return ``sum(log(r) + r)`` for the acceptance ratio ``r``.

        Args:
            true_samples: tensor that must not require gradients.
            gen_samples: samples passed first to ``acceptance_ratio``.
            b: accepted for signature parity with the other objectives in
                this file; not used by this loss.

        Raises:
            AssertionError: if ``true_samples.requires_grad`` is set.
        """
        assert not true_samples.requires_grad
        ratio = self.discriminator.acceptance_ratio(gen_samples, true_samples)
        return ratio.log().add(ratio).sum()


class ConventionalCrossEntropy(nn.Module):
    """Summed GAN-style cross-entropy on discriminator probabilities.

    Computes ``sum(-log D(true) - log(1 - D(gen)))`` where ``D`` is
    ``discriminator.d``. No clamping is applied, so outputs of exactly 0
    (for true samples) or 1 (for generated samples) yield ``inf``.
    """

    def __init__(self, discriminator):
        super().__init__()
        # Any object exposing ``d(samples)``; presumably returns values in
        # (0, 1) — TODO confirm against the discriminator implementation.
        self.discriminator = discriminator

    def forward(self, true_samples, gen_samples, b=1e-10):
        """Return the summed cross-entropy for one batch.

        Args:
            true_samples: tensor that must not require gradients.
            gen_samples: samples scored as the "fake" class.
            b: accepted for signature parity with the other objectives in
                this file; not used by this loss.

        Raises:
            AssertionError: if ``true_samples.requires_grad`` is set.
        """
        assert not true_samples.requires_grad
        score = self.discriminator.d
        real_term = torch.log(score(true_samples))
        fake_term = torch.log(1.0 - score(gen_samples))
        return torch.sum(-real_term - fake_term)


class MarkovCrossEntropy(nn.Module):
    """Pairwise cross-entropy on discriminator score differences.

    With ``x = discriminator(true_samples)`` and
    ``y = discriminator(gen_samples)``, the loss is

        sum(-log sigmoid(x - y) - log(1 - sigmoid(y - x)))

    Because ``1 - sigmoid(y - x) == sigmoid(x - y)``, both terms equal
    ``-logsigmoid(x - y)``. They are evaluated in log space with
    ``F.logsigmoid`` so that large score gaps produce a finite loss and
    gradient instead of ``inf`` / ``nan`` from ``log(0)``.
    """

    def __init__(self, discriminator):
        super(MarkovCrossEntropy, self).__init__()
        # Callable mapping a batch of samples to scores (logits).
        self.discriminator = discriminator

    def forward(self, true_samples, gen_samples, b=1e-10):
        """Return the summed pairwise cross-entropy for one batch.

        Args:
            true_samples: tensor that must not require gradients.
            gen_samples: samples whose scores are subtracted from the
                true-sample scores.
            b: accepted for signature parity with the other objectives in
                this file; not used by this loss.

        Raises:
            AssertionError: if ``true_samples.requires_grad`` is set.
        """
        assert not true_samples.requires_grad
        x = self.discriminator(true_samples)
        y = self.discriminator(gen_samples)
        # log(d_xy)     = log sigmoid(x - y)
        # log(1 - d_yx) = log(1 - sigmoid(y - x)) = log sigmoid(x - y)
        # Computing it once in log space avoids the float saturation of
        # sigmoid that made the original form return inf for large |x - y|.
        log_d_xy = F.logsigmoid(x - y)
        loss = -2.0 * log_d_xy
        return torch.sum(loss)


class ProposalObjectiveMarkov(nn.Module):
    """Mean log-ratio of a pairwise discriminator evaluated in both orders.

    Returns ``mean(log(D(true, gen) / D(gen, true)))`` where ``D`` is the
    wrapped discriminator called with the keyword ``b``.
    """

    def __init__(self, discriminator):
        super().__init__()
        # Callable taking (first, second, b=...) and returning a tensor.
        self.discriminator = discriminator

    def forward(self, true_samples, gen_samples, b=1e-10):
        """Return the mean log-ratio for one batch.

        Args:
            true_samples: tensor that must not require gradients.
            gen_samples: samples paired against ``true_samples``.
            b: forwarded unchanged to both discriminator calls (presumably a
                small numerical floor — confirm in the discriminator).

        Raises:
            AssertionError: if ``true_samples.requires_grad`` is set.
        """
        assert not true_samples.requires_grad
        disc = self.discriminator
        forward_score = disc(true_samples, gen_samples, b=b)
        backward_score = disc(gen_samples, true_samples, b=b)
        log_ratio = torch.log(forward_score / backward_score)
        return log_ratio.mean()


class DiscriminatorObjective(nn.Module):
    """Mean ratio of a pairwise discriminator evaluated in both orders.

    Returns ``mean(D(gen, true) / D(true, gen))`` where ``D`` is the wrapped
    discriminator called with the keyword ``b``.
    """

    def __init__(self, discriminator):
        super().__init__()
        # Callable taking (first, second, b=...) and returning a tensor.
        self.discriminator = discriminator

    def forward(self, true_samples, gen_samples, b=1e-10):
        """Return the mean reverse/forward ratio for one batch.

        Args:
            true_samples: tensor that must not require gradients.
            gen_samples: tensor that must not require gradients.
            b: forwarded unchanged to both discriminator calls (presumably a
                small numerical floor — confirm in the discriminator).

        Raises:
            AssertionError: if either input has ``requires_grad`` set.
        """
        # Neither input may carry gradients; true_samples is checked first.
        assert not (true_samples.requires_grad or gen_samples.requires_grad)
        disc = self.discriminator
        forward_score = disc(true_samples, gen_samples, b=b)
        backward_score = disc(gen_samples, true_samples, b=b)
        return (backward_score / forward_score).mean()


class ELBO_VAE(nn.Module):
    """Negative ELBO: summed binary cross-entropy plus Gaussian KL term.

    The KL term is ``-0.5 * sum(1 + logs2 - mu**2 - exp(logs2))``, the
    closed-form KL between ``N(mu, exp(logs2))`` and a standard normal.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, recon_x, mu, logs2):
        """Return reconstruction BCE (summed) plus the KL divergence.

        Args:
            x: target tensor for ``F.binary_cross_entropy``.
            recon_x: reconstruction; must lie in [0, 1] as BCE requires.
            mu: posterior mean.
            logs2: posterior log-variance (``exp(logs2)`` is the variance).
        """
        reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
        kl_divergence = (1 + logs2 - mu.pow(2) - logs2.exp()).sum().mul(-0.5)
        return reconstruction + kl_divergence
