import torch
import torch.nn as nn
from torch import log as tlog
from torch.utils import data


class ZDataset(data.Dataset):
    """Map-style dataset over files named ``Z_<i>.pth`` inside one directory.

    Item ``i`` is read from ``path + '/Z_' + str(i) + '.pth'``. Each loaded
    object is indexed as a sequence: element ``[0]`` must support ``['z']``
    (presumably a dict holding the latent code) and element ``[1]`` is
    returned as the image. Indices run over ``range(max_batch)``.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    the torch version; only point this at trusted directories.
    """

    def __init__(self, path, max_batch=196):
        self.list_ID = range(max_batch)
        self.path = path

    def __len__(self):
        return len(self.list_ID)

    def __getitem__(self, idx):
        """Return ``(img_tensor, z_tensor)`` for position ``idx``."""
        file_index = self.list_ID[idx]
        record = torch.load(self.path + '/Z_' + str(file_index) + '.pth')
        image = record[1]
        latent = record[0]['z']
        return image, latent


class HmcTranstion(nn.Module):
    """Proposal that rotates the current latent state toward fresh Gaussian noise.

    ``step(x)`` returns ``x * cos(t) + p * sin(t)`` with ``p ~ N(0, I)`` of
    shape ``(x.size(0), dim)``. Because ``cos(t)**2 + sin(t)**2 == 1``, a
    standard-normal ``x`` independent of ``p`` stays standard normal.

    Args:
        t: rotation angle (Python number or tensor).
        dim: latent dimensionality; ``x`` is expected to be ``(batch, dim)``.
    """

    def __init__(self, t, dim=100):
        super(HmcTranstion, self).__init__()
        # as_tensor accepts numbers and existing tensors alike; torch.tensor()
        # warns when asked to copy-construct from a tensor.
        self.t = torch.as_tensor(t)
        self.dim = dim

    def step(self, x_s):
        """Return one proposal from state ``x_s`` (shape ``(batch, dim)``)."""
        # Draw the noise on x_s's device so GPU-resident states work; the
        # previous CPU-only draw raised a device-mismatch error there.
        p = torch.randn(x_s.size(0), self.dim, device=x_s.device)
        x_new = x_s * torch.cos(self.t) + p * torch.sin(self.t)
        return x_new


class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator ending in a Sigmoid.

    Four 4x4 stride-2 convolutions (widths ndf, 2*ndf, 4*ndf, 8*ndf) each halve
    the spatial size. Every conv after the first is followed by a non-affine
    InstanceNorm2d, and all are followed by LeakyReLU(0.2). A final 4x4 valid
    convolution produces one channel, squashed to (0, 1) by a Sigmoid. For a
    64x64 input the output shape is ``(B, 1, 1, 1)``.

    Args:
        nc: input channels. NOTE(review): the default 6 presumably means two
            3-channel images concatenated along dim 1; confirm with callers.
        ndf: base channel width.
    """

    def __init__(self, nc=6, ndf=64):
        super(Discriminator, self).__init__()
        # Keep this construction order stable: it fixes the ``main.<i>``
        # indices used as state_dict keys and the parameter init sequence.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        width = ndf
        for _ in range(3):
            layers.extend([
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(width * 2, affine=False),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            width *= 2
        layers.extend([
            nn.Conv2d(width, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)


class LogitDiscriminator(nn.Module):
    """DCGAN-style convolutional discriminator that returns raw logits.

    Same trunk as a stride-2 DCGAN discriminator: 4x4 convs with widths
    ndf -> 2*ndf -> 4*ndf -> 8*ndf, non-affine InstanceNorm2d after every conv
    but the first, and LeakyReLU(0.2). A final 4x4 valid conv maps to a single
    channel with no output activation. For a 64x64 input the output shape is
    ``(B, 1, 1, 1)``.

    Args:
        nc: input channels.
        ndf: base channel width.
    """

    def __init__(self, nc=3, ndf=64):
        super(LogitDiscriminator, self).__init__()
        # Construction order determines the ``main.<i>`` state_dict keys and
        # the parameter init sequence; do not reorder.
        trunk = [nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                 nn.LeakyReLU(0.2, inplace=True)]
        for mult in (1, 2, 4):
            c_in, c_out = ndf * mult, ndf * mult * 2
            trunk += [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                      nn.InstanceNorm2d(c_out, affine=False),
                      nn.LeakyReLU(0.2, inplace=True)]
        trunk.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*trunk)

    def forward(self, input):
        return self.main(input)


class MarkovWrap(nn.Module):
    """Pairs a latent transition with a generator to produce proposed samples.

    Args:
        G: generator module; it is switched to eval mode and moved to ``device``.
            NOTE(review): G is registered as a submodule, so calling ``.train()``
            on this wrapper would put G back into training mode.
        z_transition: any object exposing ``step(z) -> z_new``.
        device: device the generator input is moved to.
        dim_z: latent size; proposals are reshaped to ``(B, dim_z, 1, 1)``.
    """

    def __init__(self, G, z_transition, device, dim_z=100):
        super(MarkovWrap, self).__init__()
        self.G = G.eval().to(device)
        self.dim_z = dim_z
        self.z_transition = z_transition
        self.device = device

    def forward(self, z_s):
        """Return ``(G(z_p), z_p)`` where ``z_p`` is the proposal from ``z_s``.

        ``z_p`` is returned on whatever device ``z_transition.step`` produced;
        only the copy fed to G is moved to ``self.device``.
        """
        batch_size = z_s.size(0)
        proposal = self.z_transition.step(z_s)
        gen_input = proposal.view(batch_size, self.dim_z, 1, 1).to(self.device)
        generated = self.G(gen_input)
        return generated, proposal


class DWrapperRelativeMCE(nn.Module):
    """Relativistic acceptance test built on a single-input discriminator.

    ``D`` is applied to each state separately; its outputs are treated as
    logits ``d_s = D(x_s)`` and ``d_p = D(x_p)`` (presumably from a model like
    ``LogitDiscriminator`` that has no final Sigmoid; confirm with callers).

    * ``forward(x_s, x_p) = sigmoid(d_s - d_p)``
    * ``log_test(x_s, x_p) = log sigmoid(d_p - d_s) - log sigmoid(d_s - d_p)
      = d_p - d_s``
    * ``log_ar = min(0, log_test)`` and ``ar = exp(log_ar)`` lie in ``[0, 1]``.
    """

    def __init__(self, D):
        super(DWrapperRelativeMCE, self).__init__()
        self.D = D

    def logit(self, x):
        return self.D(x)

    def log_test(self, x_s, x_p):
        """Return the log acceptance statistic ``d_p - d_s``.

        Uses the identity ``log sigmoid(a) - log sigmoid(-a) == a``. Evaluating
        it through sigmoid/log overflowed to +/-inf once ``|a|`` exceeded about
        100 in float32, because the smaller sigmoid underflowed to 0.
        """
        d_s = self.logit(x_s)
        d_p = self.logit(x_p)
        return d_p - d_s

    def log_ar(self, x_s, x_p):
        """Return the log acceptance ratio ``min(0, log_test)``."""
        return torch.clamp(self.log_test(x_s, x_p), max=0.0)

    def ar(self, x_s, x_p):
        """Return the acceptance probability ``min(1, exp(d_p - d_s))``.

        The value is the same as ``clamp(sigmoid(d_p-d_s) / sigmoid(d_s-d_p), 0, 1)``.
        Computing it in log space avoids dividing two saturating sigmoids.
        """
        return torch.exp(self.log_ar(x_s, x_p))

    def forward(self, x_s, x_p):
        return torch.sigmoid(self.logit(x_s) - self.logit(x_p))


class DWrapper(nn.Module):
    """Pairwise acceptance test from a discriminator over concatenated inputs.

    ``D`` receives ``cat([x, y], dim=1)``, and its output is used as a value
    inside ``log``, so it is presumably a probability (e.g. from a Sigmoid-
    terminated model); confirm with callers.

    * ``log_test = log D(x_p, x_s) - log D(x_s, x_p)``, with both scores
      smoothed as ``(1 - 1e-12) * d + 1e-12`` so ``log`` never sees 0.
    * ``log_ar = min(0, log_test)`` and ``ar = exp(log_ar)``.
    """

    def __init__(self, D):
        super(DWrapper, self).__init__()
        self.D = D

    def forward(self, x, y):
        stacked = torch.cat([x, y], dim=1)
        return self.D(stacked)

    def log_test(self, x_s, x_p):
        eps = 1e-12
        keep = 1. - eps
        # Evaluate (x_s, x_p) before (x_p, x_s): D may be stateful, so call order matters.
        score_sp = keep * self.forward(x_s, x_p) + eps
        score_ps = keep * self.forward(x_p, x_s) + eps
        return tlog(score_ps) - tlog(score_sp)

    def log_ar(self, x_s, x_p):
        return torch.clamp(self.log_test(x_s, x_p), max=0.0)

    def ar(self, x_s, x_p):
        return torch.exp(self.log_ar(x_s, x_p))


