# Code from: https://github.com/ritheshkumar95/pytorch-vqvae

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence

from functions import vq, vq_st

from dst_utils import *
from dst import *
from copy import deepcopy
from sparsemax import Sparsemax

import pdb


def to_scalar(arr):
    """Convert a 0-d tensor, or a list/tuple of them, to Python number(s).

    Args:
        arr: a single-element tensor, or a list/tuple of single-element tensors.

    Returns:
        A Python number for a single tensor, or a list of numbers for a
        list/tuple input.
    """
    # isinstance (not type() ==) so list subclasses and tuples are handled too.
    if isinstance(arr, (list, tuple)):
        return [x.item() for x in arr]
    return arr.item()


def weights_init(m):
    """Xavier-uniform initialise conv-like modules and zero their bias.

    Intended for use with ``nn.Module.apply``. Only modules whose class name
    contains ``'Conv'`` are touched. A matching module without a ``weight``
    (for example a container such as ``GatedMaskedConv2d``) is reported and
    skipped. A conv created with ``bias=False`` still gets its weight
    initialised, and the missing bias is ignored.
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname:
        return

    # getattr with a default: nn.Module raises AttributeError for missing attrs.
    weight = getattr(m, 'weight', None)
    if weight is None:
        print("Skipping initialization of ", classname)
        return

    nn.init.xavier_uniform_(weight.data)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0)


class VAE(nn.Module):
    """Convolutional VAE with a diagonal Gaussian posterior and N(0, I) prior.

    Args:
        input_dim: number of image channels.
        dim: hidden channel width.
        z_dim: number of latent channels.
    """

    def __init__(self, input_dim, dim, z_dim):
        super().__init__()
        # The encoder's last layer emits 2 * z_dim channels, split in forward()
        # into mean (first half) and log-variance (second half).
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, 2 * z_dim, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(2 * z_dim),
        )

        # Mirror of the encoder; Tanh bounds reconstructions to (-1, 1).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(z_dim, dim, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(dim, dim, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(dim, input_dim, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

        self.apply(weights_init)

    def forward(self, x):
        """Encode, sample with the reparameterisation trick, and decode.

        Returns:
            (x_tilde, kl_div): the reconstruction, and KL(q(z|x) || N(0, I))
            summed over latent channels and averaged over all other dims.
        """
        stats = self.encoder(x)
        mu, logvar = stats.chunk(2, dim=1)
        std = logvar.mul(0.5).exp()

        posterior = Normal(mu, std)
        prior = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        per_element_kl = kl_divergence(posterior, prior)
        kl_div = per_element_kl.sum(1).mean()

        z = posterior.rsample()
        x_tilde = self.decoder(z)
        return x_tilde, kl_div


class VQEmbedding(nn.Module):
    """Codebook of K vectors of size D used to quantise encoder outputs.

    Args:
        K: number of codebook entries.
        D: size of each codebook vector (channel count of the encoder output).
    """

    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        bound = 1. / K
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, z_e_x):
        """Apply ``vq`` to the channels-last view of ``z_e_x`` (B, D, H, W).

        The result is what ``VectorQuantizedVAE.encode`` returns as latents;
        presumably the per-position codebook indices (see ``functions.vq``).
        """
        channels_last = z_e_x.permute(0, 2, 3, 1).contiguous()
        return vq(channels_last, self.embedding.weight)

    def straight_through(self, z_e_x):
        """Quantise ``z_e_x`` and return ``(z_q_x, z_q_x_bar)``, both channels-first.

        ``z_q_x`` comes from ``vq_st`` with a *detached* codebook, so that path
        cannot update the embedding. ``z_q_x_bar`` re-gathers the chosen
        vectors from the live weight so a loss on it can train the codebook.
        """
        channels_last = z_e_x.permute(0, 2, 3, 1).contiguous()
        quantized, indices = vq_st(channels_last, self.embedding.weight.detach())
        z_q_x = quantized.permute(0, 3, 1, 2).contiguous()

        gathered = torch.index_select(self.embedding.weight, dim=0, index=indices)
        z_q_x_bar = gathered.view_as(channels_last).permute(0, 3, 1, 2).contiguous()

        return z_q_x, z_q_x_bar


class ResBlock(nn.Module):
    """Residual block: ReLU, 3x3 conv, BN, ReLU, 1x1 conv, BN, added to the input.

    NOTE(review): the first ``ReLU`` is in-place, so ``forward`` rectifies the
    caller's ``x``. As a result, the skip connection adds ``relu(x)``, not
    ``x``. This is kept as-is for compatibility with trained weights. Use a
    non-inplace first ReLU if a true identity skip is intended.

    Args:
        dim: number of input/output channels.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # self.block mutates x in place (leading ReLU) before the addition.
        branch = self.block(x)
        return x + branch


class VectorQuantizedVAE(nn.Module):
    """VQ-VAE: strided conv encoder, K-entry codebook, transposed-conv decoder.

    Args:
        input_dim: number of image channels.
        dim: hidden channel width; also the codebook vector size.
        K: number of codebook entries.
    """

    def __init__(self, input_dim, dim, K=512):
        super().__init__()
        # Two stride-2 convolutions downsample twice before quantisation.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1),
            ResBlock(dim),
            ResBlock(dim),
        )

        self.codebook = VQEmbedding(K, dim)

        # Two stride-2 transposed convolutions undo the downsampling.
        self.decoder = nn.Sequential(
            ResBlock(dim),
            ResBlock(dim),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(dim, input_dim, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

        self.apply(weights_init)

    def encode(self, x):
        """Return the codebook latents for images ``x`` (output of ``self.codebook``)."""
        return self.codebook(self.encoder(x))

    def decode(self, latents):
        """Look up integer code indices in the codebook and decode to images."""
        embedded = self.codebook.embedding(latents)
        z_q_x = embedded.permute(0, 3, 1, 2)  # (B, D, H, W)
        return self.decoder(z_q_x)

    def forward(self, x):
        """Return ``(x_tilde, z_e_x, z_q_x)``.

        ``x_tilde`` is decoded from the straight-through quantisation;
        ``z_e_x`` is the raw encoder output; ``z_q_x`` is the codebook vectors
        gathered from the live (trainable) embedding weight.
        """
        z_e_x = self.encoder(x)
        z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x)
        x_tilde = self.decoder(z_q_x_st)
        return x_tilde, z_e_x, z_q_x


class GatedActivation(nn.Module):
    """Gated unit: split channels in half and return tanh(first) * sigmoid(second)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        content, gate = x.chunk(2, dim=1)
        return torch.tanh(content) * torch.sigmoid(gate)


class GatedMaskedConv2d(nn.Module):
    """One class-conditional gated PixelCNN layer (vertical + horizontal stacks).

    Args:
        mask_type: ``'A'`` zeroes the taps that would let a position see its own
            row (vertical) / its own pixel (horizontal); any other value leaves
            the kernels unmasked (used as mask ``'B'``).
        dim: number of feature channels in and out.
        kernel: odd kernel size.
        residual: if True, add ``x_h`` to the horizontal output.
        n_classes: number of conditioning labels.
    """

    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
        super().__init__()
        assert kernel % 2 == 1, "Kernel size must be odd"
        self.mask_type = mask_type
        self.residual = residual

        self.class_cond_embedding = nn.Embedding(
            n_classes, 2 * dim
        )

        # Vertical stack: (ceil(k/2), k) kernel. The conv output has k//2 extra
        # rows; forward() keeps the first H so output row i covers input rows
        # i - k//2 .. i.
        kernel_shp = (kernel // 2 + 1, kernel)
        padding_shp = (kernel // 2, kernel // 2)
        self.vert_stack = nn.Conv2d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)

        # Horizontal stack: (1, ceil(k/2)) kernel; after cropping to W, output
        # column j covers input columns j - k//2 .. j of the same row.
        kernel_shp = (1, kernel // 2 + 1)
        padding_shp = (0, kernel // 2)
        self.horiz_stack = nn.Conv2d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        self.horiz_resid = nn.Conv2d(dim, dim, 1)

        self.gate = GatedActivation()

    def make_causal(self):
        """Zero the last kernel row (vertical) and last kernel column (horizontal)."""
        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column

    def forward(self, x_v, x_h, h):
        """Run the layer.

        Args:
            x_v: vertical-stack input, (B, dim, H, W).
            x_h: horizontal-stack input, (B, dim, H, W).
            h: class labels, one per batch element.

        Returns:
            (out_v, out_h), each with the same spatial size as the inputs.
        """
        # Re-applied every call so the masked taps stay zero even if the
        # weights have been updated since the last forward.
        if self.mask_type == 'A':
            self.make_causal()

        cond = self.class_cond_embedding(h)[:, :, None, None]

        # Crop the vertical output to the input's HEIGHT (dim -2) and the
        # horizontal output to its WIDTH (dim -1); this keeps non-square
        # inputs aligned.
        h_vert = self.vert_stack(x_v)[:, :, :x_v.size(-2), :]
        out_v = self.gate(h_vert + cond)

        h_horiz = self.horiz_stack(x_h)[:, :, :, :x_h.size(-1)]
        v2h = self.vert_to_horiz(h_vert)

        out = self.gate(v2h + h_horiz + cond)
        out_h = self.horiz_resid(out)
        if self.residual:
            out_h = out_h + x_h

        return out_v, out_h


class GatedPixelCNN(nn.Module):
    """Class-conditional gated PixelCNN over a grid of discrete codes.

    Args:
        input_dim: number of discrete values per position (embedding rows and
            number of output logits).
        dim: channel width of the gated layers.
        n_layers: number of ``GatedMaskedConv2d`` layers; the first uses mask
            'A' with kernel 7, the rest mask 'B' with kernel 3.
        n_classes: number of conditioning labels.
    """

    def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10):
        super().__init__()
        self.dim = dim
        print("dim: ", dim)

        # Create embedding layer to embed input
        self.embedding = nn.Embedding(input_dim, dim)

        # Building the PixelCNN layer by layer
        self.layers = nn.ModuleList()

        self.n_classes = n_classes

        # Initial block with Mask-A convolution
        # Rest with Mask-B convolutions
        for i in range(n_layers):
            mask_type = 'A' if i == 0 else 'B'
            kernel = 7 if i == 0 else 3
            residual = False if i == 0 else True

            self.layers.append(
                GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes)
            )

        # Output head: 1x1 conv to 512 hidden channels, ReLU, 1x1 conv to logits.
        self.output_conv = nn.Sequential(
            nn.Conv2d(dim, 512, 1),
            nn.ReLU(True),
            nn.Conv2d(512, input_dim, 1)
        )

        self.apply(weights_init)

    def _blank_canvas(self, shape, batch_size):
        """Return an all-zero int64 grid of size (batch_size, *shape) on the model's device."""
        param = next(self.parameters())
        return torch.zeros(
            (batch_size, *shape),
            dtype=torch.int64, device=param.device
        )

    def forward(self, x, label, horiz=0, vert=0):
        """Compute logits for every position.

        Args:
            x: integer code grid, (B, H, W).
            label: class labels passed to every layer's conditional embedding.
            horiz: row index of the position whose hidden features are returned.
            vert: column index of that position.

        Returns:
            (logits, x_linear): logits of shape (B, input_dim, H, W), and the
            post-ReLU hidden features (B, 512) at (horiz, vert), i.e. the
            input of the final 1x1 conv.
        """
        shp = x.size() + (-1, )
        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)

        x_v, x_h = (x, x)
        for layer in self.layers:
            x_v, x_h = layer(x_v, x_h, label)

        # Evaluate the hidden part of the head once and reuse it for both the
        # logits and the features handed to DST.
        hidden = self.output_conv[1](self.output_conv[0](x_h))
        x_linear = hidden[:, :, horiz, vert]
        return self.output_conv[2](hidden), x_linear

    def generate(self, label, shape=(8, 8), batch_size=64):
        """Sample a (batch_size, *shape) code grid in raster order from softmax(logits)."""
        x = self._blank_canvas(shape, batch_size)

        for i in range(shape[0]):
            for j in range(shape[1]):
                logits, _ = self.forward(x, label, i, j)
                probs = F.softmax(logits[:, :, i, j], -1)
                x.data[:, i, j].copy_(
                    probs.multinomial(1).squeeze().data
                )
        return x

    def generate_dst(self, label, shape=(8, 8), batch_size=64, full=False):
        """Sample like ``generate``, but drop classes DST assigns zero singleton mass.

        At each position a ``DST`` object is built from the final 1x1 conv's
        weight and bias and the hidden features at that position. Softmax
        probabilities of classes whose ``output_mass_singletons`` entry is
        exactly 0 are zeroed, and each row is renormalised before sampling.

        Args:
            full: forwarded to ``DST.get_output_mass``.

        Returns:
            (x, max_num): the sampled int64 grid, and the largest number of
            classes left with non-zero probability in any row at any position
            (0 if nothing was sampled).
        """
        x = self._blank_canvas(shape, batch_size)
        max_num = 0

        # The output layer is fixed during sampling: export it to numpy once.
        # Conv weight is (K, J, 1, 1); transpose to get [J, K] for DST.
        out_layer = self.output_conv[2]
        weight = torch.transpose(out_layer.weight.data, 0, 1)
        weight_np = weight.view(weight.shape[0], weight.shape[1]).cpu().numpy()
        bias = out_layer.bias.data
        bias_np = bias.view(bias.shape[0], 1).cpu().numpy()
        # The mask below indexes the columns of `probs`, so DST must produce
        # one singleton mass per logit (not per hidden feature).
        num_classes = out_layer.out_channels

        for i in range(shape[0]):
            for j in range(shape[1]):
                logits, x_linear = self.forward(x, label, i, j)
                probs = F.softmax(logits[:, :, i, j], -1)
                # detach() so this also works with grad enabled (deepcopy of a
                # non-leaf tensor that requires grad raises).
                norm_singletons = probs.detach().clone()

                for batch in range(batch_size):
                    features_np = x_linear[batch].detach().view(1, -1).cpu().numpy()

                    dst_obj = DST()
                    # IN THIS CASE - ALL the inputs are the same - all one input class!
                    dst_obj.weights_from_linear_layer(
                        weight_np, bias_np, features_np, features_np.flatten()
                    )
                    dst_obj.get_output_mass(num_classes=num_classes, full=full)

                    zero_mass = (dst_obj.output_mass_singletons == 0.).flatten()
                    mask = torch.as_tensor(
                        zero_mass, dtype=torch.bool, device=norm_singletons.device
                    )
                    norm_singletons[batch, mask] = 0.

                # NOTE(review): if DST zeroes every class in a row this divides
                # by zero, and the resulting NaN row is invalid for multinomial.
                norm_singletons = norm_singletons / norm_singletons.sum(dim=1, keepdim=True)

                num = torch.max(torch.sum(norm_singletons > 0, dim=1))
                if num > max_num:
                    max_num = num

                x.data[:, i, j].copy_(
                    norm_singletons.multinomial(1).squeeze().data
                )
        return x, max_num

    def generate_sparsemax(self, label, shape=(8, 8), batch_size=64, full=False):
        """Sample like ``generate``, but from sparsemax(logits) instead of softmax.

        ``full`` is accepted but unused.

        Returns:
            (x, max_num): the sampled int64 grid, and the largest number of
            classes with non-zero sparsemax probability in any row at any
            position (0 if nothing was sampled).
        """
        x = self._blank_canvas(shape, batch_size)
        max_num = 0
        sparsemax = Sparsemax(dim=-1)

        for i in range(shape[0]):
            for j in range(shape[1]):
                logits, _ = self.forward(x, label, i, j)
                probs_sparsemax = sparsemax(logits[:, :, i, j])

                num = torch.max(torch.sum(probs_sparsemax > 0, dim=1))
                if num > max_num:
                    max_num = num

                x.data[:, i, j].copy_(
                    probs_sparsemax.multinomial(1).squeeze().data
                )
        return x, max_num
