import numpy as np
import torch
from torch import nn
from torch.distributions.normal import Normal

''' Adopted/Modified from https://github.com/yoonholee/pytorch-vae'''

class VAE(nn.Module):
    """Base class for a VAE trained with a multi-sample (IWAE-style) objective.

    This class provides the prior, the per-sample ELBO computation, a
    log-mean-exp helper and the ``forward`` pass that combines them. The
    model-specific pieces (``proc_data``, ``encode``, ``decode``, ``lpxz``
    and ``get_arch_kl``) are empty stubs here and are expected to be
    overridden by subclasses; as written they return ``None``.
    """

    def __init__(self, device, z_dim, args):
        """Store configuration and build the standard-normal prior.

        Args:
            device: Device on which the prior's parameters are placed.
            z_dim: Size of the latent vector; the prior has this many
                independent components.
            args: Opaque configuration object, stored as ``self.args``.
        """
        super().__init__()
        self.args = args
        self.device = device
        # Bookkeeping attributes; they are only initialised here and are not
        # updated anywhere in this class.
        self.train_step = 0
        self.best_loss = np.inf
        self.best_kl = np.inf
        # Element-wise N(0, 1) prior over the z_dim latent components.
        self.prior = Normal(
            torch.zeros([z_dim]).to(device), torch.ones([z_dim]).to(device)
        )

    def proc_data(self, x):
        """Pre-process the raw input ``x`` before the likelihood is evaluated.

        Stub: subclasses must override. ``elbo`` passes the result on to
        ``lpxz``.
        """
        pass

    def encode(self, x, arch_n):
        """Return the approximate posterior distribution for ``x``.

        Stub: subclasses must override. ``forward`` calls ``rsample`` and
        (through ``elbo``) ``log_prob`` on the returned object.
        """
        pass

    def decode(self, z, mask_matrix, n_layers):
        """Return the data distribution for latent samples ``z``.

        Stub: subclasses must override.

        NOTE(review): ``forward`` calls ``self.decode(sample, arch_n)`` with
        two arguments, which does not match the three parameters declared
        here -- presumably subclasses override this with a two-argument
        form; confirm against the subclasses.
        """
        pass

    def lpxz(self, true_x, x_dist):
        """Return the data log-likelihood of ``true_x`` under ``x_dist``.

        Stub: subclasses must override. ``elbo`` adds the result to the
        negative KL term.
        """
        pass

    def get_arch_kl(self):
        """Return the architecture KL terms.

        Stub: subclasses must override. ``forward`` unpacks the result into
        two values ``(enc_kl_arch, dec_kl_arch)``.
        """
        pass

    def elbo(self, true_x, approx_posterior_sample, x_dist, approx_posterior):
        """Compute the single-sample ELBO estimate for each latent sample.

        The estimate is ``log p(x|z) + log p(z) - log q(z|x)``, where the
        prior and posterior log-densities are summed over the last (latent)
        dimension.

        Args:
            true_x: Raw input; it is passed through ``proc_data`` first.
            approx_posterior_sample: Latent samples drawn from
                ``approx_posterior``.
            x_dist: Decoder output handed to ``lpxz`` together with the
                processed input.
            approx_posterior: Distribution object exposing ``log_prob``.

        Returns:
            ``lpxz - kl``, with ``kl = log q(z|x) - log p(z)`` evaluated at
            the given samples.
        """
        true_x = self.proc_data(true_x)
        # data likelihood (reconstruction term)
        lpxz = self.lpxz(true_x, x_dist)

        # SGVB^A: log p(z) - log q(z|x) + log p(x|z)
        # Sample-based KL estimate; log-densities are summed over the last dim.
        lpz = self.prior.log_prob(approx_posterior_sample).sum(-1)
        lqzx = approx_posterior.log_prob(approx_posterior_sample).sum(-1)
        kl = -lpz + lqzx

        return -kl + lpxz

    def logmeanexp(self, inputs, dim=1):
        """Return ``log(mean(exp(inputs)))`` along ``dim``.

        The reduced dimension is kept with size 1 so that the result shape
        is the same whether or not the early-return path is taken; callers
        ``squeeze`` it when they want it removed.

        Args:
            inputs: Tensor of log-values.
            dim: Dimension to reduce over.

        Returns:
            Tensor with the same shape as ``inputs`` except that ``dim`` has
            size 1.
        """
        if inputs.size(dim) == 1:
            # The log-mean-exp of a single element is the element itself.
            return inputs

        # Subtract the per-slice maximum before exponentiating so that large
        # log-values do not overflow.
        input_max = inputs.max(dim, keepdim=True)[0]
        # keepdim=True is required here: without it the reduced tensor only
        # lines up with ``input_max`` by broadcasting when dim == 0, and any
        # other axis silently produces a wrong-shaped result.
        shifted_mean = (inputs - input_max).exp().mean(dim, keepdim=True)
        return shifted_mean.log() + input_max

    def forward(self, true_x, arch_n, mean_num, imp_n):
        """Compute the multi-sample objective for a batch.

        Args:
            true_x: Input batch, forwarded to ``encode`` and ``elbo``.
            arch_n: Number of architectures; forwarded to ``encode`` and
                ``decode`` and used to derive the samples per architecture.
            mean_num: Number of importance-weighted estimators that are
                averaged (second axis after the reshape).
            imp_n: Number of importance samples per estimator (first axis
                after the reshape).

        Returns:
            Tuple ``(elbo, loss, enc_kl_arch, dec_kl_arch)`` where ``elbo``
            is the per-sample ELBO reshaped to ``(imp_n, mean_num, -1)``,
            ``loss`` is the negated mean of the averaged importance-weighted
            estimators, and the last two items come from ``get_arch_kl``.
        """
        approx_posterior = self.encode(true_x, arch_n)

        # Number of samples drawn per architecture.
        # NOTE(review): nothing checks that mean_num * imp_n is a multiple of
        # arch_n; otherwise the reshape below cannot line up with the number
        # of samples drawn -- confirm callers guarantee this.
        L = int((mean_num*imp_n)/arch_n)
        # Reparameterised samples from the posterior.
        approx_posterior_sample = approx_posterior.rsample((L, ))  # L, arch_n, batch_size, z_dim

        # Data distribution from the decoder.
        decoder_dist = self.decode(approx_posterior_sample, arch_n)  # L, arch_n, batch_size, H*W

        # Per-sample ELBO.
        elbo = self.elbo(true_x, approx_posterior_sample, decoder_dist, approx_posterior)  # L, arch_n, batch_size

        # Regroup the samples according to the values of M and K.
        elbo = torch.reshape(elbo, (imp_n, mean_num, -1))  # imp_n, mean_num, batch_size

        # Combine the importance samples; logmeanexp keeps the reduced axis
        # with size 1, so squeeze it away.
        elbo_iwae = self.logmeanexp(elbo, 0).squeeze(0)  # mean_num, batch_size

        # Average the IWAE estimators for better SNR of the gradients.
        E_elbo_iwae = torch.mean(elbo_iwae, dim=0)  # batch_size

        # Loss objective: negative batch mean.
        loss = - torch.mean(E_elbo_iwae, 0)  # scalar
        # KL divergence for the architecture variables.
        enc_kl_arch, dec_kl_arch = self.get_arch_kl()

        return elbo, loss, enc_kl_arch, dec_kl_arch