class DiscriminatorCifar(nn.Module):
    """Convolutional discriminator producing one flat vector of raw scores.

    The network is three stride-2 4x4 convolutions (the last two followed by
    non-affine instance normalisation), each with a LeakyReLU(0.2), and a
    final 4x4 convolution without padding that maps to a single channel.
    No activation is applied to the output.

    For 32x32 inputs the stride-2 stages reduce the spatial size to 4x4, so
    the last convolution yields exactly one value per sample.
    NOTE(review): 32x32 is assumed from the class name; other input sizes
    would produce several values per sample, all merged by ``flatten()``.

    Args:
        nc: Number of channels in the input images.
        ndf: Number of feature maps after the first convolution; later
            stages use ``2 * ndf`` and ``4 * ndf``.
    """

    def __init__(self, nc=3, ndf=64):
        super(DiscriminatorCifar, self).__init__()
        widths = (ndf, ndf * 2, ndf * 4)

        # Stem: the first down-sampling stage has no normalisation layer.
        layers = [
            nn.Conv2d(nc, widths[0], kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Remaining down-sampling stages: conv -> instance norm -> LeakyReLU.
        # Modules are appended in the same order as a hand-written
        # Sequential so parameter names (main.0, main.2, ...) stay the same.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2,
                                    padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(out_ch, affine=False))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Head: collapse the remaining 4x4 window into one score channel.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1,
                                padding=0, bias=False))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the network and return its output flattened to one dimension."""
        scores = self.main(input)
        return scores.flatten()


class DiscriminatorUB(nn.Module):
    """Wrap a discriminator and compare samples by the difference of scores.

    Args:
        discriminator: Module (or callable) mapping a batch of samples to
            scores; it is stored as ``self.discriminator`` and called as-is.
    """

    def __init__(self, discriminator):
        super(DiscriminatorUB, self).__init__()
        self.discriminator = discriminator

    def forward(self, input):
        """Return the wrapped discriminator's output for ``input``."""
        return self.discriminator(input)

    def acceptance_ratio(self, next_sample, prev_sample):
        """Return ``d(next, prev) / d(prev, next)`` element-wise.

        Each sample is scored once. Analytically the result equals
        ``exp(score(next_sample) - score(prev_sample))``.
        """
        score_next = self(next_sample)
        score_prev = self(prev_sample)
        prefer_next = (score_next - score_prev).sigmoid()
        prefer_prev = (score_prev - score_next).sigmoid()
        return prefer_next / prefer_prev

    def d(self, x, y):
        """Return ``sigmoid(score(x) - score(y))``."""
        margin = self(x) - self(y)
        return torch.sigmoid(margin)


class DiscriminatorCCE(DiscriminatorCifar):
    """Wrap a discriminator and expose odds-ratio acceptance ratios.

    The wrapped module's output is treated as a logit: ``d(x)`` is its
    sigmoid, and ``acceptance_ratio`` is the ratio of the odds
    ``d / (1 - d)`` of the two samples.

    NOTE(review): the base-class constructor is invoked with its defaults,
    so this wrapper also carries the inherited ``main`` layers, although
    ``forward`` only uses ``self.discriminator``. The base class is left
    as-is here because changing it would alter ``state_dict`` keys; confirm
    whether ``nn.Module`` was the intended base.

    Args:
        discriminator: Module (or callable) mapping a batch of samples to
            logits; stored as ``self.discriminator``.
    """

    def __init__(self, discriminator):
        super(DiscriminatorCCE, self).__init__()
        self.discriminator = discriminator

    def forward(self, input):
        """Return the wrapped discriminator's raw output for ``input``."""
        return self.discriminator(input)

    def acceptance_ratio(self, next_sample, prev_sample):
        """Return the odds ratio ``[d_x / (1 - d_x)] * [(1 - d_y) / d_y]``.

        With ``d = sigmoid(logit)`` the odds ``d / (1 - d)`` equal
        ``exp(logit)``, so the ratio is ``exp(logit_x - logit_y)``. It is
        evaluated in that form because the sigmoid saturates to exactly 1.0
        for moderately large logits, which makes ``1 - d`` zero and turns
        the probability-space formula into ``inf`` or ``nan`` (``inf * 0``).

        Args:
            next_sample: Batch of proposed samples (``x``).
            prev_sample: Batch of current samples (``y``).

        Returns:
            Tensor of element-wise acceptance ratios.
        """
        logit_next = self(next_sample)
        logit_prev = self(prev_sample)
        return torch.exp(logit_next - logit_prev)

    def d(self, x):
        """Return ``sigmoid`` of the wrapped discriminator's output on ``x``."""
        return torch.sigmoid(self(x))


class DiscriminatorMCE(DiscriminatorCifar):
    """Wrap a discriminator and compare samples by the difference of scores.

    NOTE(review): the base-class constructor is invoked with its defaults,
    so this wrapper also carries the inherited ``main`` layers, although
    ``forward`` only uses ``self.discriminator`` — confirm this is intended.

    Args:
        discriminator: Module (or callable) mapping a batch of samples to
            scores; stored as ``self.discriminator``.
    """

    def __init__(self, discriminator):
        super(DiscriminatorMCE, self).__init__()
        self.discriminator = discriminator

    def forward(self, input):
        """Return the wrapped discriminator's output for ``input``."""
        return self.discriminator(input)

    def acceptance_ratio(self, next_sample, prev_sample):
        """Return ``d(next, prev) / d(prev, next)`` element-wise.

        Each sample is scored once. Analytically the result equals
        ``exp(score(next_sample) - score(prev_sample))``.
        """
        score_next = self(next_sample)
        score_prev = self(prev_sample)
        prefer_next = (score_next - score_prev).sigmoid()
        prefer_prev = (score_prev - score_next).sigmoid()
        return prefer_next / prefer_prev

    def d(self, x, y):
        """Return ``sigmoid(score(x) - score(y))``."""
        margin = self(x) - self(y)
        return torch.sigmoid(margin)
