import torch
from torch import nn
from .layers import MLPBlock
from .ArchitectureSampler import SampleNetworkArchitecture


class AdaptiveEncoder(nn.Module):
    """MLP encoder whose effective depth and unit masks are sampled per call.

    A stack of ``truncation`` ``MLPBlock`` layers is built up front. On every
    forward pass ``SampleNetworkArchitecture`` supplies a mask tensor and a
    layer count; only that many leading blocks are applied, each with its own
    slice of the mask. Two linear heads then map the final activations to the
    parameters of the variational distribution over the latent code.

    Parameters
    ----------
    input_dim : int
        Input feature size of the first block.
    num_neurons : int
        Width of every hidden block.
    z_dim : int
        Output size of both heads (latent dimensionality).
    a_prior, b_prior : float
        Forwarded unchanged to ``SampleNetworkArchitecture``.
    num_samples : int
        Forwarded to ``SampleNetworkArchitecture`` and stored on the instance.
        NOTE: ``forward`` has its own ``num_samples`` argument and does not
        read this attribute.
    truncation : int
        Number of blocks in the stack (at least one block is always built).
    device : torch.device
        Device on which the blocks and heads are placed.
    """

    def __init__(self, input_dim, num_neurons=400, z_dim=20, a_prior=2.0, b_prior=2.0, num_samples=5, truncation=40, device=torch.device("cuda:0")):
        super().__init__()
        self.mode = "NN"
        self.input_dim = input_dim
        self.num_neurons = num_neurons
        self.truncation = truncation
        self.num_samples = num_samples
        self.device = device

        # Architecture sampler (described by the original author as a
        # stick-breaking process); yields masks and the number of layers.
        self.architecture_sampler = SampleNetworkArchitecture(num_neurons=num_neurons,
                                                              a_prior=a_prior,
                                                              b_prior=b_prior,
                                                              num_samples=num_samples,
                                                              truncation=truncation,
                                                              device=self.device)

        # Stack of MLP blocks up to the truncation level. Per the original
        # author's note, MLPBlock applies the mask and the residual
        # connection internally. The first block maps input_dim -> num_neurons;
        # the remaining ones are num_neurons -> num_neurons with residual=True.
        blocks = [MLPBlock(self.input_dim, self.num_neurons)]
        blocks.extend(
            MLPBlock(self.num_neurons, self.num_neurons, residual=True)
            for _ in range(1, self.truncation)
        )
        self.layers = nn.ModuleList(blocks).to(self.device)

        # Heads producing the variational distribution parameters. They are
        # placed on the same device as the blocks so the module is usable
        # without an extra ``.to(device)`` call.
        self.enc_mu = nn.Linear(self.num_neurons, z_dim).to(self.device)
        self.enc_sig = nn.Linear(self.num_neurons, z_dim).to(self.device)

    def _forward(self, x, mask_matrix, threshold):
        """Run ``x`` through the first ``threshold`` blocks with their masks.

        Parameters
        ----------
        x : Tensor
            Input data. It is broadcast with ``expand`` to a 3-D tensor whose
            leading size is ``mask_matrix.shape[0]``.
        mask_matrix : Tensor
            Indexed as ``mask_matrix[:, :, layer_idx]``; the third axis selects
            the mask for each block.
        threshold : int
            Number of leading blocks to apply. Values larger than the stack
            are clamped to ``len(self.layers)`` in both training and eval mode.

        Returns
        -------
        Tensor
            Activations after the last applied block.
        """
        # Never index past the blocks that actually exist (or past the mask
        # slices that correspond to them), regardless of train/eval mode.
        threshold = min(threshold, len(self.layers))

        # One copy of the batch per entry along mask_matrix's first axis.
        x = x.expand(mask_matrix.shape[0], -1, -1)
        for layer_idx in range(threshold):
            # Insert a singleton axis so the mask broadcasts over axis 1 of x.
            mask = mask_matrix[:, :, layer_idx].unsqueeze(1)
            x = self.layers[layer_idx](x, mask)

        return x

    def forward(self, x, num_samples=5):
        """Encode ``x`` under freshly sampled architectures.

        Parameters
        ----------
        x : Tensor
            Input data.
        num_samples : int
            Number of architectures requested from the sampler for this call.

        Returns
        -------
        mu : Tensor
            Output of the ``enc_mu`` head.
        _std : Tensor
            Raw output of the ``enc_sig`` head. No positivity transform is
            applied here; how it is interpreted (std, log-variance, ...) is up
            to the caller.
        """
        # Sample an architecture; only the masks and the layer count are used.
        mask_matrix, _pi, n_layers, _ = self.architecture_sampler(num_samples)

        act_vec = self._forward(x, mask_matrix, n_layers)

        # Parameters of the variational distribution.
        mu = self.enc_mu(act_vec)
        _std = self.enc_sig(act_vec)
        return mu, _std

    def get_kl(self):
        """Return whatever ``architecture_sampler.get_kl()`` reports."""
        return self.architecture_sampler.get_kl()


class AdaptiveDecoder(nn.Module):
    """MLP decoder whose effective depth and unit masks are sampled per call.

    Mirrors the encoder layout: a stack of ``truncation`` ``MLPBlock`` layers
    is built up front, ``SampleNetworkArchitecture`` decides on each forward
    pass how many leading blocks to apply and with which masks, and a final
    linear layer maps the activations to ``output_dim``.

    Parameters
    ----------
    output_dim : int
        Output feature size of the final linear layer.
    z_dim : int
        Input feature size of the first block (latent dimensionality).
    num_neurons : int
        Width of every hidden block.
    a_prior, b_prior : float
        Forwarded unchanged to ``SampleNetworkArchitecture``.
    num_samples : int
        Forwarded to ``SampleNetworkArchitecture`` and stored on the instance.
        NOTE: ``forward`` has its own ``num_samples`` argument and does not
        read this attribute.
    truncation : int
        Number of blocks in the stack (at least one block is always built).
    device : torch.device
        Device on which the blocks and the output layer are placed.
    """

    def __init__(self, output_dim, z_dim, num_neurons=400, a_prior=2.0, b_prior=2.0, num_samples=5, truncation=40, device=torch.device("cuda:0")):
        super().__init__()
        self.output_dim = output_dim
        self.num_neurons = num_neurons
        self.device = device
        self.truncation = truncation
        self.num_samples = num_samples

        # Architecture sampler; yields masks and the number of layers to use.
        self.architecture_sampler = SampleNetworkArchitecture(num_neurons=num_neurons,
                                                              a_prior=a_prior,
                                                              b_prior=b_prior,
                                                              num_samples=num_samples,
                                                              truncation=truncation,
                                                              device=self.device)

        # First block maps z_dim -> num_neurons; the remaining ones are
        # num_neurons -> num_neurons with residual=True.
        blocks = [MLPBlock(z_dim, self.num_neurons)]
        blocks.extend(
            MLPBlock(self.num_neurons, self.num_neurons, residual=True)
            for _ in range(1, truncation)
        )
        self.layers = nn.ModuleList(blocks).to(self.device)

        # Output projection, placed on the same device as the blocks so the
        # module is usable without an extra ``.to(device)`` call.
        self.output_layer = nn.Linear(num_neurons, output_dim).to(self.device)
        # Kept for attribute compatibility; ``forward`` does not apply it.
        self.act_fn = nn.Tanh()

    def forward(self, x, num_samples=5):
        """Decode ``x`` under freshly sampled architectures.

        Parameters
        ----------
        x : Tensor
            Latent input. Unlike the encoder, no ``expand`` is applied here,
            so ``x`` must already broadcast against the per-layer masks
            (``mask_matrix[:, :, layer_idx].unsqueeze(1)``).
        num_samples : int
            Number of architectures requested from the sampler for this call.

        Returns
        -------
        Tensor
            Raw output of ``output_layer`` (no activation is applied).
        """
        # Sample an architecture; only the masks and the layer count are used.
        mask_matrix, _pi, n_layers, _ = self.architecture_sampler(num_samples)

        # Never apply more blocks than the stack contains.
        threshold = min(n_layers, len(self.layers))

        for layer_idx in range(threshold):
            # Insert a singleton axis so the mask broadcasts over axis 1 of x.
            mask = mask_matrix[:, :, layer_idx].unsqueeze(1)
            x = self.layers[layer_idx](x, mask)

        return self.output_layer(x)

    def get_kl(self):
        """Return whatever ``architecture_sampler.get_kl()`` reports."""
        return self.architecture_sampler.get_kl()
