import numpy as np
import torch
from torch import nn
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.normal import Normal
from .AdaptiveMLP import AdaptiveEncoder, AdaptiveDecoder
from .vae_base import VAE


class MLPVAE(VAE):
    """VAE over flattened images with adaptive MLP encoder and decoder.

    The encoder output parameterises a ``Normal`` posterior over the latent
    code. The decoder output is used as the logits of a ``Bernoulli``
    likelihood over the flattened input.
    """

    def __init__(self, args, device, img_shape, h_dim, z_dim, truncation=25):
        """Build the adaptive encoder/decoder pair.

        Args:
            args: configuration object, forwarded to the ``VAE`` base class.
            device: device that inputs are moved to in ``proc_data`` and that
                is passed to the adaptive networks.
            img_shape: shape of one input sample; its product is the
                flattened input dimension ``x_dim``.
            h_dim: hidden width given to the encoder (positionally) and to the
                decoder (as ``num_neurons``).
            z_dim: latent dimensionality.
            truncation: forwarded unchanged to both adaptive networks.
        """
        super().__init__(device, z_dim, args)
        # np.prod returns a NumPy scalar; normalise it to a plain Python int
        # before handing it to layer constructors and ``reshape``.
        x_dim = int(np.prod(img_shape))
        self.img_shape = img_shape
        self._x_dim = x_dim
        self._data_device = device
        # A bound method instead of a lambda keeps the module picklable
        # (e.g. ``torch.save(model)``). It is still assigned as an instance
        # attribute after super().__init__ so it takes precedence exactly as
        # the previous lambda did.
        self.proc_data = self._flatten_input

        # Initialise the encoding and decoding networks.
        self.encoder = AdaptiveEncoder(
            x_dim, h_dim, z_dim, truncation=truncation, device=device
        )
        self.decoder = AdaptiveDecoder(
            x_dim, z_dim, num_neurons=h_dim, truncation=truncation, device=device
        )

    def _flatten_input(self, x):
        """Move ``x`` to the configured device and reshape it to ``(-1, x_dim)``."""
        return x.to(self._data_device).reshape(-1, self._x_dim)

    def init(self, module):
        """Initialise a single ``nn.Linear`` layer; other modules are ignored.

        Weights get Xavier-uniform initialisation with the tanh gain. Biases,
        when present, are set to 0.01.
        NOTE(review): presumably used as ``model.apply(model.init)``; the call
        site is not visible here.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(
                module.weight, gain=nn.init.calculate_gain("tanh")
            )
            # Layers built with bias=False have ``bias is None``.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.01)

    def encode(self, x, num_arch_samples=5):
        """Return the approximate posterior q(z | x) as a ``Normal``.

        ``x`` is flattened with ``proc_data``. The encoder's second output is
        passed through softplus to obtain a non-negative scale.
        ``num_arch_samples`` is forwarded to the adaptive encoder.
        """
        x = self.proc_data(x)
        mu, raw_scale = self.encoder(x, num_arch_samples)
        return Normal(mu, nn.functional.softplus(raw_scale))

    def decode(self, z, num_arch_samples):
        """Return p(x | z) as a ``Bernoulli`` built from the decoder's logits.

        ``num_arch_samples`` is forwarded to the adaptive decoder.
        """
        logits = self.decoder(z, num_arch_samples)
        return Bernoulli(logits=logits)

    def lpxz(self, true_x, x_dist):
        """Log-likelihood of ``true_x`` under ``x_dist``, summed over the last axis.

        NOTE(review): assumes ``true_x`` is already flattened to match the
        decoder output (e.g. via ``proc_data``); confirm with callers.
        """
        return x_dist.log_prob(true_x).sum(-1)

    def get_arch_kl(self):
        """Return ``(encoder_kl, decoder_kl)`` as reported by the adaptive networks."""
        return self.encoder.get_kl(), self.decoder.get_kl()