import os
import sys
# Add this file's directory to sys.path so the sibling modules imported
# below (encoders, priors) resolve regardless of the working directory.
file_dir = os.path.dirname(__file__)
sys.path.append(file_dir)

from encoders import CNN_encoder,MLP_encoder, Linear_encoder_causal,MLP_encoder_causal, Inverse_Observation, CNN_encoder_causal
from priors import LRRNN
import torch.nn as nn
import torch
import scipy
import numpy as np


class VAE(nn.Module):
    
    """
    VAE with RNN / dynamical systems prior
    """
    
    def __init__(self, vae_params):
        """
        Build the prior, the encoder and the observation rectifier from a config dict.

        Args:
            vae_params (dict): configuration. Keys read here:
                'dim_x', 'dim_z', 'dim_N', 'causal', 'prior_architecture',
                'prior_params', 'enc_architecture', 'enc_params';
                optional 'dim_x_hat' (defaults to 'dim_x'), 'dim_u'
                (defaults to 0) and 'obs_rectify' ('exp' when absent).
        """
        super(VAE, self).__init__()
        self.dim_x = vae_params['dim_x']
        # Observation dimension handed to the prior; falls back to dim_x
        self.dim_x_hat = vae_params.get('dim_x_hat', vae_params['dim_x'])
        self.dim_u = vae_params.get('dim_u', 0)

        self.dim_z = vae_params['dim_z']
        self.dim_N = vae_params['dim_N']
        self.vae_params = vae_params
        # Membership test: the former `== "PLRNN" or "LRRNN"` was always truthy,
        # so the warning branch below could never be reached.
        if vae_params['prior_architecture'] in ("PLRNN", "LRRNN"):
            self.prior = LRRNN(self.dim_x_hat,self.dim_z,self.dim_u,self.dim_N,vae_params['prior_params'])
        else:
            print("WARNING: prior does not exist")
        print(vae_params['enc_architecture'])
        self.has_encoder = True
        if vae_params['enc_architecture']=="CNN":
            self.encoder = CNN_encoder(self.dim_x,self.dim_z,vae_params['enc_params'])
        elif vae_params['enc_architecture']=="CNN_causal":
            self.encoder = CNN_encoder_causal(self.dim_x,self.dim_z,vae_params['enc_params'])
        elif vae_params['enc_architecture']=='MLP':
            self.encoder = MLP_encoder(self.dim_x,self.dim_z,vae_params['enc_params'])
        elif vae_params['enc_architecture']=='Linear_causal':
            self.encoder =  Linear_encoder_causal(self.dim_x,self.dim_z,vae_params['enc_params'])
        elif vae_params['enc_architecture']=="MLP_causal":
            self.encoder = MLP_encoder_causal(self.dim_x,self.dim_z,vae_params['enc_params'])
        elif vae_params['enc_architecture']=="Inv_Obs":
            self.encoder = Inverse_Observation(self.dim_x,self.dim_z,vae_params['enc_params'],self.prior.inv_observation)
        elif vae_params['enc_architecture']=='None':
            # No learned encoder: the prior's inverse observation stands in for the mean
            self.has_encoder = False
            self.encoder = lambda x: (torch.zeros_like(x),self.prior.inv_observation(x),torch.zeros_like(x))
        else:
            print("WARNING: encoder does not exist")

        # Variance clamping bounds used by the forward passes
        self.min_var = 1e-6
        self.max_var = 100

        self.causal = vae_params['causal']
        self.MSE_loss = nn.MSELoss()

        # Rectifier mapping the observation pre-activation to a Poisson rate
        obs_rectify = self.vae_params.get('obs_rectify', 'exp')
        if obs_rectify == 'exp':
            self.obs_rectify = torch.exp
        elif obs_rectify == 'relu':
            self.obs_rectify = lambda x: torch.relu(x)+1e-10
        elif obs_rectify == 'softplus':
            self.obs_rectify = torch.nn.functional.softplus
        else:
            print("WARNING: obs_rectify does not exist, use one of: exp, relu, softplus")

    def Gauss_ll(self, X, mean, var):
        """
        Diagonal-Gaussian log-likelihood, summed over the feature axis.

        Args:
            X (torch.tensor; BS, dim_x, K): observations
            mean (torch.tensor; BS, dim_x, K): Gaussian means
            var (torch.tensor; BS, dim_x, K): element-wise variances
        Returns:
            torch.tensor (BS, K): sum over dim_x of log N(X | mean, var)
        """
        squared_error = (mean - X) ** 2
        per_element = -.5 * (torch.log(var) + squared_error / var + np.log(2 * np.pi))
        return per_element.sum(axis=1)
    
    def Gauss_ll_full(self, X, mean,var):
        """
        Multivariate Gaussian log-likelihood with one covariance shared by all samples.

        Args:
            X (torch.tensor; BS, dim_x, K): input data
            mean (torch.tensor; BS, dim_x, K): mean of the Gaussian
            var (torch.tensor; dim_x, dim_x): covariance matrix (symmetric positive definite)
        Returns:
            ll (torch.tensor; BS, K): log likelihood
        """
        diff = X - mean
        # var^{-1} @ diff via a linear solve rather than an explicit inverse
        solved = torch.linalg.solve(var, diff)
        mahalanobis = (diff * solved).sum(axis=1)
        # logdet is more stable than log(det(.)), which can under/overflow
        log_det = torch.logdet(var)
        # The normalising constant scales with the dimension: D * log(2*pi)
        return -.5*(log_det + mahalanobis + X.shape[1]*np.log(2*np.pi))
    def Gauss_ll_full_chol(self, X, mean, chol_var):
        """
        Multivariate Gaussian log-likelihood given a Cholesky factor of the covariance.

        Args:
            X (torch.tensor; BS, dim_x, K): input data
            mean (torch.tensor; BS, dim_x, K): mean of the Gaussian
            chol_var (torch.tensor; dim_x, dim_x): lower-triangular factor L
                with covariance = L @ L.T
        Returns:
            ll (torch.tensor; BS, K): log likelihood
        """
        dim = X.shape[1]
        centred = X - mean
        # Rebuild (L L^T)^{-1} directly from the lower factor
        precision = torch.cholesky_inverse(chol_var)
        # log|L L^T| = 2 * sum(log(diag(L)))
        log_det = 2 * torch.log(torch.diagonal(chol_var)).sum()
        quad = torch.einsum('bzk,zl,blk->bk', centred, precision, centred)
        return -.5 * (log_det + quad + dim * np.log(2 * np.pi))
    
    def Poisson_ll(self, X, log_rate,*args):
        """
        Poisson log likelihood of counts X, summed over axis 1.

        `log_rate` is turned into a rate by the rectifier chosen with
        vae_params['obs_rectify']: 'exp' (default when the key is absent),
        'relu' (plus 1e-10) or 'softplus'.

        Args:
            X (torch.tensor): observed counts
            log_rate (torch.tensor; same shape as X): pre-rectification rate
            *args: ignored (presumably so it can be called like Gauss_ll)
        Returns:
            torch.tensor: log likelihood summed over axis 1
        Raises:
            ValueError: if vae_params['obs_rectify'] is not a supported option
        """
        rectify = self.vae_params.get('obs_rectify', 'exp')
        if rectify == 'exp':
            # log(exp(l)) == l, so use log_rate directly instead of log(rate)
            ll = X*log_rate-torch.exp(log_rate)-torch.lgamma(X+1)
        elif rectify == 'relu':
            rate = torch.relu(log_rate)+1e-10
            ll = X*torch.log(rate)-rate-torch.lgamma(X+1)
        elif rectify == 'softplus':
            rate = torch.nn.functional.softplus(log_rate)
            ll = X*torch.log(rate)-rate-torch.lgamma(X+1)
        else:
            # Previously `ll` was left unbound here (UnboundLocalError)
            raise ValueError(
                "obs_rectify %r does not exist, use one of: exp, relu, softplus" % (rectify,))
        return ll.sum(axis=1)

    
    def ll_pz_analytical(self, var_q, mean_q, var_p, mean_p):
        """
        Cross entropy E_q[log p(z)] between diagonal Gaussians q and p.

        Returns -0.5 * sum over axis 1 of
        log(2*pi*var_p) + (mean_p - mean_q)^2 / var_p + var_q / var_p.
        """
        log_term = torch.log(var_p * np.pi * 2)
        mean_term = ((mean_p - mean_q) ** 2) / var_p
        spread_term = var_q / var_p
        return -.5 * torch.sum(log_term + mean_term + spread_term, axis=1)
    
    def ll_qz_analytical(self, var_q):
        """
        Negative entropy of a diagonal Gaussian with variances `var_q`:
        -0.5 * sum over axis 1 of (log(2*pi*var_q) + 1).
        """
        per_dim = torch.log(var_q * np.pi * 2) + 1
        return -.5 * torch.sum(per_dim, axis=1)

    
    def resample(self, Qz, indices):
        """
        Reorder particles along axis 2 of Qz according to ancestor indices.

        Args:
            Qz (torch.tensor; BS, dim_z, K): particles
            indices (torch.LongTensor; BS, K): ancestor index per output particle
        Returns:
            torch.tensor (BS, dim_z, K) with out[b, d, j] = Qz[b, d, indices[b, j]]
        """
        # Broadcast the per-particle index over the latent dimension
        gather_index = indices.unsqueeze(1).expand(Qz.shape)
        return Qz.gather(2, gather_index)

    def sample_indices_systematic(self,log_weight):
        """Sample ancestral index using systematic resampling.
        FROM: https://github.com/tuananhle7/aesmc

        One uniform offset u per batch row is shared by the K positions
        (u + i) / K, and each position is located in the row's cumulative
        normalised weights.

        Args:
            log_weight: log of unnormalized weights, tensor
                [batch_size, num_particles]
        Returns:
            ancestral index: LongTensor [batch_size, num_particles]
        Raises:
            FloatingPointError: if log_weight contains NaN
        """
        if torch.isnan(log_weight).any():
            raise FloatingPointError('log_weight contains nan element(s)')

        batch_size, num_particles = log_weight.size()
        # Resampling is done in double precision and outside the autograd graph
        log_weight = log_weight.to(dtype=torch.double).detach()
        uniforms = torch.rand(size=[batch_size, 1], device=log_weight.device, dtype=log_weight.dtype)
        pos = (uniforms + torch.arange(0, num_particles, device=log_weight.device)) / num_particles

        normalized_weights = torch.exp(log_weight - torch.logsumexp(log_weight, axis=1, keepdims=True))
        cumulative_weights = torch.cumsum(normalized_weights, axis=1)
        # Rescale so the last cumulative weight is exactly 1, preventing
        # round-off from leaving a position beyond the final bin
        row_max = cumulative_weights.max(dim=1, keepdim=True).values
        cumulative_weights = cumulative_weights / row_max

        # Batched equivalent of torch.bucketize(pos[b], cumulative_weights[b])
        # for every row b (both use right=False), without a Python loop
        return torch.searchsorted(cumulative_weights, pos)
    
    def sample_indices_multinomial(self, log_w):
        """Sample ancestral index using multinomial resampling.

        Args:
            log_w: log of unnormalized weights, tensor
                [batch_size, num_particles]
        Returns:
            ancestral index: LongTensor [batch_size, num_particles], drawn with
            replacement from each row's normalised, detached weights
        """
        num_particles = log_w.shape[1]
        log_normaliser = torch.logsumexp(log_w, dim=1, keepdim=True)
        probs = (log_w - log_normaliser).exp().detach()
        # Renormalise to absorb round-off from the exp
        probs = probs / probs.sum(1, keepdim=True)
        return torch.multinomial(probs, num_particles, replacement=True)
    
    def norm_and_detach_weights(self, log_w):
        """
        Turn log-weights into normalised weights over the last axis, detached
        from the autograd graph (a max-shifted softmax).
        """
        shifted = log_w.detach()
        # Subtract the row maximum before exponentiating to avoid overflow
        shifted = shifted - shifted.max(-1, keepdim=True)[0]
        unnormalised = shifted.exp()
        return unnormalised / unnormalised.sum(-1, keepdim=True)

    def forward_Optimal_VGTF(self,x,u=None, k=1,MC_p=True, dreg_p="none", MC_q=True, dreg_q="none", resample=False,out_likelihood="Gauss",bootstrap=False):
        """
        Forward pass of the VAE with the locally optimal (Kalman) proposal.

        At every step the approximate posterior is the Kalman update of the
        RNN prior prediction with the current observation, which is exact for
        linear-Gaussian observations x_t ~ N(B^T z_t + bias, R_x) and
        non-linear latent dynamics.

        Args:
            x (torch.tensor; n_trials x dim_X x time_steps): input data
            u (torch.tensor; n_trials x dim_U x time_steps): input stim.
                Required: it is indexed unconditionally despite the None default.
            k (int): number of particles
            resample (str): "multinomial", "systematic" or "none"; the default
                False (and None) are treated like "none"
            MC_p, dreg_p, MC_q, dreg_q, out_likelihood, bootstrap: not used,
                kept for interface compatibility
        Returns:
            Loss (per trial): time average of log mean_k(w_t), differentiable
            Qzs (n_trials x dim_z x time_steps x k): sampled latents
            torch.zeros_like(Qzs): placeholder (no encoder samples here)
            log_xs, log_pzs (per trial): time averages of log mean_k of the
                detached observation / prior log densities
            -log_qzs (per trial): negated time average of log mean_k of the
                detached proposal log density
            log_likelihood (per trial): detached version of Loss
            alphas ((time_steps-1) x dim_z x dim_z): alpha = K B^T for t >= 1
        """
        MVN = torch.distributions.MultivariateNormal

        # Effective noise parameters of the prior and the observation model
        eff_var_prior = self.prior.full_cov_embed(self.prior.R_z) #Dz,Dz
        eff_var_x_diag = torch.diag(torch.clip(self.prior.var_embed_x(self.prior.R_x),1e-8)) #Dx,Dx
        eff_std_x = torch.clip(self.prior.std_embed_x(self.prior.R_x),1e-4)
        eff_var_prior_t0 = self.prior.full_cov_embed(self.prior.R_z_t0) #Dz,Dz
        eff_var_prior_chol = self.prior.chol_cov_embed(self.prior.R_z) #Dz,Dz
        eff_var_prior_t0_chol = self.prior.chol_cov_embed(self.prior.R_z_t0) #Dz,Dz
        time_steps = x.shape[2]

        # Initial prior mean replicated over the k particles. Without the
        # expand only a single particle was ever propagated while every
        # estimate below still subtracted log(k).
        prior_mean = self.prior.get_initial_state(u[:,:,0]).unsqueeze(2).expand(-1, -1, k)

        x = x.unsqueeze(-1) # for k

        # Linear readout; used as (Dz, Dx) in the einsums below
        if self.prior.params['readout_rates']=='currents':
            B = self.prior.observation.cast_B(self.prior.observation.B)@self.prior.transition.m_transform(self.prior.transition.m)
            B = B.T
        else:
            B = self.prior.observation.cast_B(self.prior.observation.B)
        Obs_bias  = self.prior.observation.Bias.squeeze(-1)

        log_ws, log_ll, ll_xs, ll_pzs, ll_qzs, Qzs, alphas = [], [], [], [], [], [], []

        def _log_mean_exp(v):
            # log of the mean of exp(v) over the particle axis
            return torch.logsumexp(v, axis=-1) - np.log(k)

        def _kalman(prior_cov):
            # Kalman gain K, alpha = K B^T, (I - alpha) and the Cholesky factor
            # of the Joseph-stabilised posterior covariance (symmetrised, +1e-8 jitter)
            gain = prior_cov@B@torch.linalg.inv(eff_var_x_diag+B.T@prior_cov@B)
            alpha = gain@B.T
            eye = torch.eye(self.dim_z,device=alpha.device)
            one_min_alpha = eye-alpha
            var_Q = one_min_alpha@prior_cov@one_min_alpha.T + gain@eff_var_x_diag@gain.T
            var_Q = eye*1e-8+(var_Q+var_Q.T)/2
            return gain, alpha, one_min_alpha, torch.linalg.cholesky(var_Q)

        def _propose(prior_mean, prior_chol, gain, one_min_alpha, var_Q_cholesky, x_t):
            # Sample the Kalman posterior and score it under q, the prior and the observation model
            mean_Q = torch.einsum('zs,BsK->BzK',one_min_alpha,prior_mean)+torch.einsum('zx,BxK->BzK',gain,x_t-Obs_bias)
            Q_dist = MVN(loc=mean_Q.permute(0,2,1), scale_tril=var_Q_cholesky)
            Qz = Q_dist.rsample()
            ll_qz = Q_dist.log_prob(Qz)
            ll_pz = MVN(loc=prior_mean.permute(0,2,1), scale_tril=prior_chol).log_prob(Qz)
            Qz = Qz.permute(0,2,1)
            mean_x = torch.einsum("zx, bzk -> bxk",B, Qz)+Obs_bias
            x_dist = torch.distributions.Normal(loc=mean_x.permute(0,2,1), scale=eff_std_x)
            ll_x = x_dist.log_prob(x_t.permute(0,2,1)).sum(axis=-1)
            return Qz, ll_x, ll_pz, ll_qz

        def _record(Qz, ll_x, ll_pz, ll_qz):
            # Store per-step estimates and return the log importance weights
            log_w = ll_x+ll_pz-ll_qz
            log_ws.append(_log_mean_exp(log_w))
            log_ll.append(_log_mean_exp(log_w.detach()))
            ll_xs.append(_log_mean_exp(ll_x.detach()))
            ll_pzs.append(_log_mean_exp(ll_pz.detach()))
            ll_qzs.append(_log_mean_exp(ll_qz.detach()))
            Qzs.append(Qz)
            return log_w

        def _resample(particles, log_w):
            if resample == "multinomial":
                return self.resample(particles, self.sample_indices_multinomial(log_w))
            if resample == "systematic":
                return self.resample(particles, self.sample_indices_systematic(log_w))
            # False/None (the default) mean "no resampling" and must not warn every step
            if resample not in ("none", False, None):
                print("WARNING: resample does not exist")
                print("use, one of: multinomial, systematic, none")
            return particles

        # t = 0: update the initial-state prior with the first observation
        gain, alpha, one_min_alpha, var_Q_cholesky = _kalman(eff_var_prior_t0)
        Qz, ll_x, ll_pz, ll_qz = _propose(prior_mean, eff_var_prior_t0_chol, gain, one_min_alpha, var_Q_cholesky, x[:,:,0])
        log_w = _record(Qz, ll_x, ll_pz, ll_qz)

        # The transition covariance is time-invariant, so the Kalman quantities
        # for t >= 1 are computed once instead of at every step
        gain, alpha, one_min_alpha, var_Q_cholesky = _kalman(eff_var_prior)
        u = u.unsqueeze(-1) #account for k
        for t in range(1,time_steps):
            Qz = _resample(Qz, log_w)
            prior_mean = self.prior(Qz.unsqueeze(2),u=u[:,:,t].unsqueeze(2),noise_scale=0).squeeze(2)
            Qz, ll_x, ll_pz, ll_qz = _propose(prior_mean, eff_var_prior_chol, gain, one_min_alpha, var_Q_cholesky, x[:,:,t])
            log_w = _record(Qz, ll_x, ll_pz, ll_qz)
            alphas.append(alpha)

        # Average over time (equivalent to summing and dividing by time_steps)
        log_likelihood = torch.sum(torch.stack(log_ll),axis=0)/time_steps
        Loss = torch.sum(torch.stack(log_ws),axis=0)/time_steps
        log_xs = torch.sum(torch.stack(ll_xs),axis=0)/time_steps
        log_pzs = torch.sum(torch.stack(ll_pzs),axis=0)/time_steps
        log_qzs = torch.sum(torch.stack(ll_qzs),axis=0)/time_steps
        if alphas:
            alphas = torch.stack(alphas)
        else:
            # time_steps == 1: torch.stack([]) would raise
            alphas = alpha.new_zeros((0,) + tuple(alpha.shape))

        Qzs = torch.stack(Qzs).permute(1,2,0,3)
        return Loss,Qzs,torch.zeros_like(Qzs), log_xs, log_pzs, -log_qzs, log_likelihood,alphas
    

    def forward_VGTF(self,x,u=None,k=1,resample=False,out_likelihood="Gauss",t_forward=0):
        """
        Forward pass of the VAE (variational particle filter).

        The proposal at each step is the precision-weighted combination of the
        RNN prior prediction and the encoder output (both diagonal Gaussians).
        After the encoded window, `t_forward` further steps are propagated with
        the prior alone (bootstrap steps).

        Args:
            x (torch.tensor; n_trials x dim_X x time_steps): input data; the last
                t_forward steps are not shown to the encoder
            u (torch.tensor; n_trials x dim_U x time_steps): input stim.
                Required: it is indexed unconditionally despite the None default.
            k (int): number of particles
            resample (str): "multinomial", "systematic" or "none"; the default
                False (and None) are treated like "none"
            out_likelihood (str): "Gauss" or "Poisson"
            t_forward (int): number of trailing steps predicted by the prior only
        Returns:
            Loss (per trial): time average of log mean_k(w_t), differentiable
            Qzs (n_trials x dim_z x T x k): sampled latents, T including t_forward
            Esample: samples returned by the encoder
            log_xs, log_pzs: time averages of log mean_k of the detached
                observation / prior log densities
            -log_qzs: negated time average of log mean_k of the detached
                proposal log density
            log_likelihood: detached version of Loss
            alphas: stacked per-step weights of the encoder mean in the proposal mean
        Raises:
            ValueError: if out_likelihood is not "Gauss" or "Poisson"
        """
        if out_likelihood=="Gauss":
            def ll_x_func(obs, mu, sd):
                return torch.distributions.Normal(loc=mu,scale=sd).log_prob(obs).sum(axis=1)
        elif out_likelihood=="Poisson":
            def ll_x_func(obs, mu, sd):
                # sd is unused: the rate is the rectified mean
                return torch.distributions.Poisson(self.obs_rectify(mu)).log_prob(obs).sum(axis=1)
        else:
            # Previously only a warning was printed and the pass then failed
            # with UnboundLocalError on the first use of ll_x_func
            raise ValueError("likelihood %r does not exist, use one of: Gauss, Poisson" % (out_likelihood,))

        # Run the encoder on everything except the forecast window
        Esample, Emean,log_Evar, _ = self.encoder(x[:,:self.dim_x,:x.shape[2]-t_forward],k=k)

        # Clamp the variances to avoid numerical issues
        Evar = torch.clamp(torch.exp(log_Evar),min=self.min_var,max= self.max_var)

        # Effective variances / std devs of the prior and decoder
        eff_var_prior = torch.clamp(self.prior.var_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1
        eff_std_prior = torch.clamp(self.prior.std_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1
        eff_var_prior_t0 = torch.clamp(self.prior.var_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1
        eff_std_prior_t0 = torch.clamp(self.prior.std_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1
        eff_std_x = torch.clamp(self.prior.std_embed_x(self.prior.R_x).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var)) #1,Dx,1

        # Align the data with the encoder output when a CNN without padding shortens it
        cl=self.encoder.cut_len
        if cl>0:
            if self.causal:
                x_hat = x[:,:,cl:].unsqueeze(-1)
            else:
                x_hat = x[:,:,cl//2:-cl//2].unsqueeze(-1)
        else:
            x_hat=x.unsqueeze(-1)

        time_steps = Esample.shape[2]
        log_ws, log_ll, ll_xs, ll_pzs, ll_qzs, Qzs, alphas = [], [], [], [], [], [], []

        def _log_mean_exp(v):
            # log of the mean of exp(v) over the particle axis
            return torch.logsumexp(v, axis=-1) - np.log(k)

        def _record(log_w, ll_x, ll_pz, ll_qz, Qz):
            log_ll.append(_log_mean_exp(log_w.detach()))
            ll_xs.append(_log_mean_exp(ll_x.detach()))
            ll_pzs.append(_log_mean_exp(ll_pz.detach()))
            ll_qzs.append(_log_mean_exp(ll_qz.detach()))
            log_ws.append(_log_mean_exp(log_w))
            Qzs.append(Qz)

        def _resample(particles, log_w):
            if resample == "multinomial":
                return self.resample(particles, self.sample_indices_multinomial(log_w))
            if resample == "systematic":
                return self.resample(particles, self.sample_indices_systematic(log_w))
            # False/None (the default) mean "no resampling" and must not warn every step
            if resample not in ("none", False, None):
                print("WARNING: resample does not exist")
                print("use, one of: multinomial, systematic, none")
            return particles

        def _filter_step(prior_mean, prior_var, prior_std, t):
            # Precision-weighted combination of the prior prediction and the encoder
            precZ = 1/prior_var
            precE = 1/Evar[:,:,t]
            precQ = precZ+precE
            alpha = precE/precQ
            mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,t]
            eff_var_Q = 1/precQ
            Q_dist = torch.distributions.Normal(loc=mean_Q,scale=torch.sqrt(eff_var_Q))
            Qz = Q_dist.rsample()
            mean_x = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)
            ll_pz = torch.distributions.Normal(loc=prior_mean,scale=prior_std).log_prob(Qz).sum(axis=1)
            ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
            ll_x = ll_x_func(x_hat[:,:,t], mean_x,eff_std_x)
            return Qz, alpha, ll_x, ll_pz, ll_qz

        # t = 0: initial-state prior
        prior_mean = self.prior.get_initial_state(u[:,:,0]).unsqueeze(2)
        Qz, alpha, ll_x, ll_pz, ll_qz = _filter_step(prior_mean, eff_var_prior_t0, eff_std_prior_t0, 0)
        alphas.append(alpha)
        log_w = ll_x+ll_pz-ll_qz
        _record(log_w, ll_x, ll_pz, ll_qz, Qz)

        u = u.unsqueeze(-1) #account for k
        for t in range(1,time_steps):
            Qz = _resample(Qz, log_w)
            prior_mean = self.prior(Qz.unsqueeze(-2),noise_scale=0,u=u[:,:,t].unsqueeze(2)).squeeze(-2)
            Qz, alpha, ll_x, ll_pz, ll_qz = _filter_step(prior_mean, eff_var_prior, eff_std_prior, t)
            alphas.append(alpha)
            log_w = ll_x+ll_pz-ll_qz
            _record(log_w, ll_x, ll_pz, ll_qz, Qz)

        # Bootstrap (prior-only) proposals for the last t_forward steps
        for t in range(time_steps,time_steps+t_forward):
            Qz = _resample(Qz, log_w)
            Qz = self.prior(Qz.unsqueeze(-2),noise_scale=1).squeeze(-2)
            mean_x = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)
            # NOTE(review): `prior_mean` here is left over from the last filtering
            # step and no `u` is passed to the prior above, so this ll_pz is only
            # a rough diagnostic - confirm intent.
            ll_pz = torch.distributions.Normal(loc = prior_mean,scale = eff_std_prior).log_prob(Qz).sum(axis=1)
            ll_x = ll_x_func(x_hat[:,:,t], mean_x,eff_std_x)
            # With the prior as proposal the weight is the observation likelihood
            # alone; ll_qz is carried over unchanged from the last filtering step
            log_w = ll_x
            _record(log_w, ll_x, ll_pz, ll_qz, Qz)

        log_likelihood = torch.mean(torch.stack(log_ll),axis=0)
        Loss = torch.mean(torch.stack(log_ws),axis=0)
        log_xs = torch.mean(torch.stack(ll_xs),axis=0)
        log_pzs = torch.mean(torch.stack(ll_pzs),axis=0)
        log_qzs = torch.mean(torch.stack(ll_qzs),axis=0)
        alphas = torch.stack(alphas)

        Qzs = torch.stack(Qzs).permute(1,2,0,3)
        return Loss,Qzs,Esample, log_xs, log_pzs, -log_qzs, log_likelihood,alphas
    
    def forward_VMPF(self,x,u=None,k=1,resample=False,out_likelihood="Gauss",t_forward=0):
        """
        Forward pass of the VAE
        Note, here the approximate posterior is a linear combination of the encoder and the RNN
        Args:
            x (torch.tensor; n_trials x dim_X x time_steps): input data
        Returns:
            Fzs_posterior (torch.tensor; n_trials x dim_z x time_steps): latent time series as predicted by the approximate posterior
            Fzs_encoder (torch.tensor; n_trials x dim_z x time_steps): latent time series as predicted by the encoder
            Fzs_prior (torch.tensor; n_trials x dim_z x time_steps): latent time series as predicted by the prior
            Esigma (torch.tensor; n_trials x dim_z x time_steps): standard deviation of the encoder
            Emean (torch.tensor; n_trials x dim_z x time_steps): mean of the approximate encoder
            Observations (torch.tensor; n_trials x dim_X x time_steps): observations
        
        """
        
        if out_likelihood=="Gauss":
            ll_x_func = lambda x,u,sd: torch.distributions.Normal(loc=u,scale=sd).log_prob(x).sum(axis=1)
        elif out_likelihood=="Poisson":
            ll_x_func = lambda x,u,sd: torch.distributions.Poisson(self.obs_rectify(u)).log_prob(x).sum(axis=1)
        else:
            print("WARNING: likelihood does not exist, use one of: Gauss, Poisson")
    
        # Run the encoder
        Esample, Emean,log_Evar, eps_sample = self.encoder(x[:,:self.dim_x,:x.shape[2]-t_forward],k=k) #Bs,Dx,T,K
        
        # Clamp the variances to avoid numerical issues
        Evar = torch.clamp(torch.exp(log_Evar),min=self.min_var,max= self.max_var)

        # Get the effective variances for the prior and decoder
        eff_var_prior = torch.clamp(self.prior.var_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1     
        eff_std_prior = torch.clamp(self.prior.std_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1       
        eff_var_prior_t0 = torch.clamp(self.prior.var_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1       
        eff_std_prior_t0 = torch.clamp(self.prior.std_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1       
        eff_std_x = torch.clamp(self.prior.std_embed_x(self.prior.R_x).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var)) #1,Dx,1

        # Cut some of the data if a CNN was used without padding
        cl=self.encoder.cut_len
        if cl>0:
            if self.causal:
                x_hat = x[:,:,cl:].unsqueeze(-1)
            else:
                x_hat = x[:,:,cl//2:-cl//2].unsqueeze(-1)
        else:
            x_hat=x.unsqueeze(-1)


        # Initialise some lists 
        bs,dim_z,time_steps,_ = Esample.shape
        #print(bs,dim_x,time_steps)
        log_ws = []
        log_ll = []
        ll_xs = []
        ll_pzs = []
        ll_qzs = []
        Qzs = []
        alphas = []

        # Get the prior mean and observation mean
        prior_mean = self.prior.get_initial_state().unsqueeze(0).unsqueeze(2)

        # get the posterior mean
        precZ = 1/eff_var_prior_t0
        precE = 1/Evar[:,:,0]
        precQ = precZ+precE
        alpha = precE/precQ

        mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,0]
        eff_var_Q=1/precQ

        Q_dist = torch.distributions.Normal(loc = mean_Q,scale=torch.sqrt(eff_var_Q))
        #Q_dist = torch.distributions.Normal(loc = Emean[:,:,0],scale=torch.sqrt(Evar[:,:,0]))

        Qz = Q_dist.rsample()
        mean_x= self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)

        # Calculate the log likelihoods
        ll_pz = torch.distributions.Normal(loc = prior_mean,scale = eff_std_prior_t0).log_prob(Qz).sum(axis=1)
        ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
        ll_x = ll_x_func(x_hat[:,:,0], mean_x,eff_std_x)

        log_w = ll_x+ll_pz-ll_qz   

        log_ll.append(torch.logsumexp(log_w.detach(), axis=-1) - np.log(k))
        ll_xs.append(torch.logsumexp(ll_x.detach(), axis=-1) - np.log(k))
        ll_pzs.append(torch.logsumexp(ll_pz.detach(), axis=-1) - np.log(k))
        ll_qzs.append(torch.logsumexp(ll_qz.detach(), axis=-1) - np.log(k))
        alphas.append(alpha)
       
        # Append the log weights and log likelihoods
        log_ws.append(torch.logsumexp(log_w, axis=-1) - np.log(k))

        # Append the log weights and log likelihoods
        Qzs.append(Qz)

        # Loop through the time steps
        for t in range(1,time_steps):
 

            if resample == "multinomial":
                indices = self.sample_indices_multinomial(log_w)
                Qz = self.resample(Qz, indices)
            elif resample == "systematic":
                indices = self.sample_indices_systematic(log_w)
                Qz = self.resample(Qz, indices)

            elif resample == "none":
                pass
            else:
                print("WARNING: resample does not exist")
                print("use, one of: multinomial, systematic, none")

            
            # prior and posterior mean
            prior_mean = self.prior(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)
            precZ = 1/eff_var_prior
            precE = 1/Evar[:,:,t]
            precQ = precZ+precE
            alpha = precE/precQ
            alphas.append(alpha)

            mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,t]
            eff_var_Q=1/precQ
            #Qz = mean_Q + torch.sqrt(eff_var_Q)*eps_sample[:,:,t]
            Q_dist = torch.distributions.Normal(loc=mean_Q,scale=torch.sqrt(eff_var_Q))
            Qz = Q_dist.rsample()
            # Get the observation mean
            mean_x = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)
            probs_ij = (torch.distributions.Normal(loc = prior_mean.unsqueeze(3),
                                                                scale=eff_std_prior.unsqueeze(3)).log_prob(Qz.unsqueeze(2)).sum(axis=1))
            Qprobs_ij = (torch.distributions.Normal(loc =mean_Q.unsqueeze(3),
                                                                scale=torch.sqrt(eff_var_Q).unsqueeze(3)).log_prob(Qz.unsqueeze(2)).sum(axis=1))

            # Calculate the log likelihoods
            ll_pz = torch.distributions.Normal(loc = prior_mean,scale=eff_std_prior).log_prob(Qz).sum(axis=1)
            ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
            ll_x = ll_x_func(x_hat[:,:,t], mean_x,eff_std_x)

            nom = torch.logsumexp(log_w.unsqueeze(1)+probs_ij,axis=-1)
            #nom0 = []
            #for j in range(k):
                #BS x K
            #    nom0.append(log_w[:,j].unsqueeze(-1)+torch.distributions.Normal(loc = prior_mean,scale=eff_std_prior).log_prob(Qz[:,:,j].unsqueeze(-1)).sum(axis=1))
            #print(torch.stack(nom0).shape)#K x BS x K
            #print(nom.shape)# BS x K
            #print(torch.norm(torch.logsumexp(torch.stack(nom0),axis=0)-nom))
            denom = torch.logsumexp(log_w.unsqueeze(1) + Qprobs_ij,axis=-1)
            # Calculate the log weights
            log_w = nom+ll_x-denom#ll_x+ll_pz-ll_qz   
            log_ll.append(torch.logsumexp(log_w.detach(), axis=-1) - np.log(k))
            ll_xs.append(torch.logsumexp(ll_x.detach(), axis=-1) - np.log(k))
            ll_pzs.append(torch.logsumexp(ll_pz.detach(), axis=-1) - np.log(k))
            ll_qzs.append(torch.logsumexp(ll_qz.detach(), axis=-1) - np.log(k))
            log_ws.append(torch.logsumexp(log_w, axis=-1) - np.log(k))
            Qzs.append(Qz)

        # Use Bootstrap samples for the last t_forward steps
        for t in range(time_steps,time_steps+t_forward):
            if resample == "multinomial":
                indices = self.sample_indices_multinomial(log_w)
                Qz = self.resample(Qz, indices)
            elif resample == "systematic":
                indices = self.sample_indices_systematic(log_w)
                Qz = self.resample(Qz, indices)

            elif resample == "none":
                pass
            else:
                print("WARNING: resample does not exist")
                print("use, one of: multinomial, systematic, none")

            Qz = self.prior(Qz.unsqueeze(-2),noise_scale=1).squeeze(-2)
            mean_x = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)
            ll_pz = torch.distributions.Normal(loc = prior_mean,scale = eff_std_prior).log_prob(Qz).sum(axis=1)
            ll_x = ll_x_func(x_hat[:,:,t], mean_x,eff_std_x)
            log_w = ll_x
            ll_qz = ll_qz
            
            log_ll.append(torch.logsumexp(log_w.detach(), axis=-1) - np.log(k))
            ll_xs.append(torch.logsumexp(ll_x.detach(), axis=-1) - np.log(k))
            ll_pzs.append(torch.logsumexp(ll_pz.detach(), axis=-1) - np.log(k))
            ll_qzs.append(torch.logsumexp(ll_qz.detach(), axis=-1) - np.log(k))
            log_ws.append(torch.logsumexp(log_w, axis=-1) - np.log(k))
            Qzs.append(Qz)


        log_ws = torch.stack(log_ws)
        log_ll = torch.stack(log_ll)
        log_xs = torch.stack(ll_xs)
        log_pzs = torch.stack(ll_pzs)
        log_qzs = torch.stack(ll_qzs)
        alphas = torch.stack(alphas)

        
        log_likelihood =torch.mean(log_ll,axis=0)
        Loss = torch.mean(log_ws,axis=0)
        log_xs =torch.mean(log_xs,axis=0)
        log_pzs =torch.mean(log_pzs,axis=0)
        log_qzs =torch.mean(log_qzs,axis=0)
        
        Qzs=torch.stack(Qzs)
        Qzs = Qzs.permute(1,2,0,3)
        
        return Loss,Qzs,Esample, log_xs, log_pzs, -log_qzs, log_likelihood,alphas
   
    def forward_VGTF_dreg(self,x,u=None,k=1,MC_p=True, dreg_p="none", MC_q=True, dreg_q="none", resample=False,out_likelihood="Gauss",bootstrap=False):
        """
        Forward pass of the VAE with a sequential Monte-Carlo (particle) posterior and
        optional doubly-reparameterised gradient surrogates.

        At every time step the approximate posterior is a precision-weighted linear
        combination of the RNN prior mean and the encoder mean:
            alpha  = precE / (precE + precZ)
            mean_Q = (1-alpha)*prior_mean + alpha*Emean[:,:,t]
        with precZ = 1/prior variance and precE = 1/encoder variance.

        Args:
            x (torch.tensor; n_trials x dim_X x time_steps): input data
            u: not used by this method (kept for interface compatibility)
            k (int): number of particles / importance samples
            MC_p (bool): Monte-Carlo estimate of log p(z) if True, otherwise
                self.ll_pz_analytical
            dreg_p (str): "none", "direct" or "all"; which gradient paths through the
                prior use the doubly-reparameterised surrogate
            MC_q (bool): Monte-Carlo estimate of log q(z) if True, otherwise
                self.ll_qz_analytical
            dreg_q (str): "none", "direct" or "all"; same for the approximate posterior
            resample (str or bool): "multinomial", "systematic" or "none";
                False / None (the default) mean "none"
            out_likelihood (str): "Gauss" or "Poisson" observation model
            bootstrap (bool): if True, the per-step objective for t >= 1 is the
                log-mean-exp of the observation log likelihood only
        Returns:
            Loss: per-step objective (log-mean-exp of the log weights, or the gradient
                surrogate when dreg is used), averaged over time
            Qzs (n_trials x dim_z x time_steps x k, presumably): posterior particles
            Esample: samples returned by the encoder
            log_xs, log_pzs: time-averaged log-mean-exp of log p(x|z) and log p(z) (detached)
            -log_qzs: negated time-averaged log-mean-exp of log q(z) (detached)
            log_likelihood: time-averaged log-mean-exp of the log weights (detached)
            alphas: stacked mixing coefficients, one entry per time step
        Raises:
            ValueError: unknown out_likelihood, or unsupported dreg_q / dreg_p combination
        """

        # Observation log likelihood summed over the observation dimension (axis 1).
        # NOTE: the third argument is a standard deviation (it is the Normal `scale`).
        if out_likelihood=="Gauss":
            ll_x_func = lambda obs,mean,sd: torch.distributions.Normal(loc=mean,scale=sd).log_prob(obs).sum(axis=1)
        elif out_likelihood=="Poisson":
            ll_x_func = lambda obs,mean,sd: torch.distributions.Poisson(self.obs_rectify(mean)).log_prob(obs).sum(axis=1)
        else:
            # Fail early instead of hitting a NameError on ll_x_func further down
            raise ValueError("likelihood does not exist, use one of: Gauss, Poisson")

        # The default resample=False (or None) means "no resampling". Report an
        # unknown scheme once instead of at every time step.
        if resample is None or resample is False:
            resample = "none"
        elif resample not in ("multinomial", "systematic", "none"):
            print("WARNING: resample does not exist")
            print("use, one of: multinomial, systematic, none")
            resample = "none"

        def resample_particles(particles, log_weights):
            """Resample `particles` according to `log_weights` with the chosen scheme."""
            if resample == "multinomial":
                return self.resample(particles, self.sample_indices_multinomial(log_weights))
            if resample == "systematic":
                return self.resample(particles, self.sample_indices_systematic(log_weights))
            return particles

        # Run the encoder (its fourth output, the encoder noise, is not needed here)
        Esample, Emean, log_Evar, _ = self.encoder(x,k=k) #Bs,Dx,T,K

        # Clamp the variances to avoid numerical issues
        Evar = torch.clamp(torch.exp(log_Evar),min=self.min_var,max= self.max_var)

        # Effective variances / standard deviations of the prior (initial state and
        # transitions) and of the decoder
        eff_var_prior = torch.clamp(self.prior.var_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1
        eff_var_prior_t0 = torch.clamp(self.prior.var_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1
        eff_std_prior_t0 = torch.clamp(self.prior.std_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1
        eff_std_prior = torch.clamp(self.prior.std_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1
        eff_std_x = torch.clamp(self.prior.std_embed_x(self.prior.R_x).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var)) #1,Dx,1

        precZ = 1/eff_var_prior_t0
        precE = 1/Evar[:,:,0]
        precQ = precZ+precE
        alpha = precE/precQ

        # Cut some of the data if a CNN was used without padding.
        # Non-causal: (-cl)//2 rounds down, so start and end cuts add up to cl.
        cl=self.encoder.cut_len
        if cl>0:
            if self.causal:
                x_hat = x[:,:,cl:].unsqueeze(-1)
            else:
                x_hat = x[:,:,cl//2:-cl//2].unsqueeze(-1)
        else:
            x_hat=x.unsqueeze(-1)

        # Initialise some lists
        bs,dim_z,time_steps,_ = Esample.shape
        log_ws = []
        log_ll = []
        ll_xs = []
        ll_pzs = []
        ll_qzs = []
        Qzs = []
        alphas = []

        # ---- t = 0: initial state -----------------------------------------
        prior_mean = self.prior.get_initial_state().unsqueeze(0).unsqueeze(2)

        mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,0]
        eff_var_Q=1/precQ
        Q_dist = torch.distributions.Normal(loc = mean_Q,scale=torch.sqrt(eff_var_Q))
        Qz = Q_dist.rsample()
        mean_x= self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)

        # Calculate the log likelihoods
        if MC_p:
            ll_pz = torch.distributions.Normal(loc = prior_mean,scale = eff_std_prior_t0).log_prob(Qz).sum(axis=1)
        else:
            ll_pz = self.ll_pz_analytical(eff_var_Q,mean_Q, eff_var_prior_t0,prior_mean)

        if MC_q:
            ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
        else:
            ll_qz = self.ll_qz_analytical(eff_var_Q)

        ll_x = ll_x_func(x_hat[:,:,0], mean_x,eff_std_x)

        log_w = ll_x+ll_pz-ll_qz
        log_ll.append(torch.logsumexp(log_w.detach(), axis=-1) - np.log(k))
        ll_xs.append(torch.logsumexp(ll_x.detach(), axis=-1) - np.log(k))
        ll_pzs.append(torch.logsumexp(ll_pz.detach(), axis=-1) - np.log(k))
        ll_qzs.append(torch.logsumexp(ll_qz.detach(), axis=-1) - np.log(k))
        alphas.append(alpha)

        if dreg_q != "none" or dreg_p != "none":
            # presumably normalised importance weights with gradients stopped
            reweight = self.norm_and_detach_weights(log_w)

            # Loss lambda (observation model parameters): sample detached
            mean_x_grad_lambda = self.prior.get_observation(Qz.unsqueeze(-2).detach(),noise_scale=0).squeeze(-2)
            ll_x_lambda = ll_x_func(x_hat[:,:,0], mean_x_grad_lambda,eff_std_x)
            Loss_lambda = (reweight*ll_x_lambda).sum(axis=-1)

            # Entropy loss phi: gradient only along the path of the sample Qz
            mean_x_grad_phi = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0,grad=False).squeeze(-2)
            ll_x_phi = ll_x_func(x_hat[:,:,0], mean_x_grad_phi,eff_std_x.detach())
            ll_prior_phi = self.Gauss_ll(Qz,prior_mean.detach(),eff_var_prior_t0.detach())

            if dreg_q != "none":
                ll_q_phi = self.Gauss_ll(Qz, mean_Q.detach(),eff_var_Q.detach())
                log_w_phi = ll_x_phi+ll_prior_phi-ll_q_phi
                Loss_phi = ((reweight**2)*log_w_phi).sum(axis=-1)
            else:
                ll_q_phi = self.Gauss_ll(Qz, mean_Q,eff_var_Q)
                log_w_phi = ll_x_phi+ll_prior_phi-ll_q_phi
                Loss_phi = (reweight*log_w_phi).sum(axis=-1)

            if dreg_p != "none":
                # Cross entropy loss theta: reparameterise the samples from Q as if
                # they came from p (value unchanged, gradient flows through p)
                eps_as_p = (Qz-prior_mean)/eff_std_prior_t0
                Qz_as_p = eps_as_p.detach()*eff_std_prior_t0+prior_mean
                mean_x_grad_theta = self.prior.get_observation(Qz_as_p.unsqueeze(-2),noise_scale=0,grad=False).squeeze(-2)
                ll_x_theta = ll_x_func(x_hat[:,:,0], mean_x_grad_theta,eff_std_x.detach())
                ll_prior_theta = self.Gauss_ll(Qz_as_p,prior_mean.detach(),eff_var_prior_t0.detach())
                ll_q_theta = self.Gauss_ll(Qz_as_p, mean_Q.detach(),eff_var_Q.detach())
                log_w_theta = ll_x_theta + ll_prior_theta - ll_q_theta
                Loss_theta = (reweight*ll_x_theta - ((reweight**2) * log_w_theta)).sum(axis=-1)
            else:
                Loss_theta = (reweight*self.Gauss_ll(Qz.detach(), prior_mean,eff_var_prior_t0)).sum(axis=-1)

            log_ws.append(Loss_lambda+Loss_phi+Loss_theta)
        else:
            log_ws.append(torch.logsumexp(log_w, axis=-1) - np.log(k))

        Qzs.append(Qz)

        # ---- t >= 1: transitions ------------------------------------------
        for t in range(1,time_steps):
            precZ = 1/eff_var_prior
            precE = 1/Evar[:,:,t]
            precQ = precZ+precE
            alpha = precE/precQ
            alphas.append(alpha)

            Qz = resample_particles(Qz, log_w)
            Qz_prev = Qz

            # prior and posterior mean
            prior_mean = self.prior(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)

            mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,t]
            eff_var_Q=1/precQ
            Q_dist = torch.distributions.Normal(loc=mean_Q,scale=torch.sqrt(eff_var_Q))
            Qz = Q_dist.rsample()
            # Get the observation mean
            mean_x = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)

            # Calculate the log likelihoods
            if MC_p:
                ll_pz = torch.distributions.Normal(loc = prior_mean,scale=eff_std_prior).log_prob(Qz).sum(axis=1)
            else:
                ll_pz = self.ll_pz_analytical(eff_var_Q,mean_Q, eff_var_prior,prior_mean)

            if MC_q:
                ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
            else:
                ll_qz = self.ll_qz_analytical(eff_var_Q)
            ll_x = ll_x_func(x_hat[:,:,t], mean_x,eff_std_x)

            # Calculate the log weights
            log_w = ll_x+ll_pz-ll_qz
            log_ll.append(torch.logsumexp(log_w.detach(), axis=-1) - np.log(k))
            ll_xs.append(torch.logsumexp(ll_x.detach(), axis=-1) - np.log(k))
            ll_pzs.append(torch.logsumexp(ll_pz.detach(), axis=-1) - np.log(k))
            ll_qzs.append(torch.logsumexp(ll_qz.detach(), axis=-1) - np.log(k))

            if bootstrap:
                log_ws.append(torch.logsumexp(ll_x, axis=-1) - np.log(k))

            elif not (dreg_q == "none" and dreg_p == "none"):
                reweight = self.norm_and_detach_weights(log_w)

                # Standard-normal noise that actually produced Qz (Qz comes from
                # rsample, not from the encoder noise), so that Qz_past_t and Qz_t
                # below equal Qz in value and differ only in their gradient paths.
                std_Q = torch.sqrt(eff_var_Q)
                eps_t = ((Qz - mean_Q)/std_Q).detach()

                # Path through the previous time steps only (current-step params detached)
                prior_mean_past_t = self.prior(Qz_prev.unsqueeze(-2),noise_scale=0,grad=False).squeeze(-2)
                mean_Q_past_t =(1-alpha.detach())*prior_mean_past_t+alpha.detach()*Emean[:,:,t].detach()
                Qz_past_t = mean_Q_past_t + std_Q.detach()*eps_t
                # Path through the current time step only (previous state detached)
                prior_mean_t = self.prior(Qz_prev.unsqueeze(-2).detach(),noise_scale=0).squeeze(-2)
                mean_Q_t = (1-alpha)*prior_mean_t+alpha*Emean[:,:,t]
                Qz_t = mean_Q_t + std_Q*eps_t

                # Loss lambda, params of observation model
                mean_x_grad_lambda = self.prior.get_observation(Qz.unsqueeze(-2).detach(),noise_scale=0).squeeze(-2)
                ll_x_lambda = ll_x_func(x_hat[:,:,t], mean_x_grad_lambda,eff_std_x)
                Loss_lambda = (reweight*ll_x_lambda).sum(axis=-1)

                # Entropy loss phi, gradients along the sample paths Qz_t / Qz_past_t
                mean_x_grad_phi_t = self.prior.get_observation(Qz_t.unsqueeze(-2),noise_scale=0,grad=False).squeeze(-2)
                ll_x_phi_t = ll_x_func(x_hat[:,:,t], mean_x_grad_phi_t,eff_std_x.detach())
                mean_x_grad_phi_past_t = self.prior.get_observation(Qz_past_t.unsqueeze(-2),noise_scale=0,grad=False).squeeze(-2)
                ll_x_phi_past_t = ll_x_func(x_hat[:,:,t], mean_x_grad_phi_past_t,eff_std_x.detach())

                ll_pz_phi_t = self.Gauss_ll(Qz_t,prior_mean.detach(),eff_var_prior.detach())
                ll_pz_phi_past_t = self.Gauss_ll(Qz_past_t,prior_mean_past_t.detach(),eff_var_prior.detach())
                # gradient through the prior mean (sample detached)
                ll_pz_phi_through_prior = self.Gauss_ll(Qz.detach(),prior_mean_past_t,eff_var_prior.detach())

                ll_qz_phi_t = self.Gauss_ll(Qz_t, mean_Q.detach(),eff_var_Q.detach())
                # gradient through the posterior parameters (sample detached)
                ll_qz_phi_through_params = self.Gauss_ll(Qz.detach(), mean_Q_t,eff_var_Q)
                ll_qz_phi_past_t = self.Gauss_ll(Qz_past_t, mean_Q_past_t,eff_var_Q.detach())

                log_w_phi_t = ll_x_phi_t+ll_pz_phi_t-ll_qz_phi_t
                log_w_phi_past_t = ll_x_phi_past_t+ll_pz_phi_past_t-ll_qz_phi_past_t

                if dreg_q=="all" and dreg_p=="all":
                    Loss_phi = ((reweight**2)*(log_w_phi_t+log_w_phi_past_t)).sum(axis=-1)
                elif dreg_q =="all" and (dreg_p=="direct" or dreg_p=="none"):
                    Loss_phi = ((reweight**2)*(log_w_phi_t + log_w_phi_past_t)+ reweight*ll_pz_phi_through_prior).sum(axis=-1)
                elif dreg_q =="direct" and dreg_p=="all":
                    Loss_phi = ((reweight**2)*log_w_phi_t +reweight*(log_w_phi_past_t)).sum(axis=-1)
                elif dreg_q =="direct" and (dreg_p=="direct" or dreg_p=="none"):
                    Loss_phi = ((reweight**2)*log_w_phi_t +reweight*(log_w_phi_past_t+ll_pz_phi_through_prior)).sum(axis=-1)
                elif dreg_q=="none" and dreg_p=="all":
                    Loss_phi = (reweight*(log_w_phi_t+log_w_phi_past_t-ll_qz_phi_through_params)).sum(axis=-1)
                elif dreg_q=="none" and dreg_p=="direct":
                    Loss_phi = (reweight*(log_w_phi_t+log_w_phi_past_t+ll_pz_phi_through_prior-ll_qz_phi_through_params)).sum(axis=-1)
                else:
                    raise ValueError(f"unsupported combination dreg_q={dreg_q!r}, dreg_p={dreg_p!r}; use 'none', 'direct' or 'all'")

                if dreg_p != "none":
                    # Cross entropy loss theta: reparameterise the samples from Q as
                    # if they came from p
                    eps_as_p = (Qz-prior_mean)/eff_std_prior
                    if dreg_p=="all":
                        Qz_as_p = eps_as_p.detach()*eff_std_prior+prior_mean
                    else:
                        Qz_as_p = eps_as_p.detach()*eff_std_prior+prior_mean_t

                    mean_x_grad_theta = self.prior.get_observation(Qz_as_p.unsqueeze(-2),noise_scale=0,grad=False).squeeze(-2)
                    Loss_ll_x_theta = ll_x_func(x_hat[:,:,t], mean_x_grad_theta,eff_std_x.detach())
                    Loss_ll_prior_theta = self.Gauss_ll(Qz_as_p,prior_mean.detach(),eff_var_prior.detach())
                    Loss_ll_q_theta = self.Gauss_ll(Qz_as_p, mean_Q.detach(),eff_var_Q.detach())
                    log_w_theta = Loss_ll_x_theta + Loss_ll_prior_theta - Loss_ll_q_theta
                    Loss_theta = (reweight*Loss_ll_x_theta - ((reweight**2) * log_w_theta)).sum(axis=-1)
                else:
                    Loss_theta = (reweight*self.Gauss_ll(Qz.detach(), prior_mean_t,eff_var_prior)).sum(axis=-1)

                log_ws.append(Loss_lambda+Loss_phi+Loss_theta)

            else:
                log_ws.append(torch.logsumexp(log_w, axis=-1) - np.log(k))
            Qzs.append(Qz)

        # Average the per-step quantities over time
        log_ws = torch.stack(log_ws)
        log_ll = torch.stack(log_ll)
        log_xs = torch.stack(ll_xs)
        log_pzs = torch.stack(ll_pzs)
        log_qzs = torch.stack(ll_qzs)
        alphas = torch.stack(alphas)

        log_likelihood = torch.sum(log_ll,axis=0)/time_steps
        Loss = torch.sum(log_ws,axis=0)/time_steps
        log_xs = torch.sum(log_xs,axis=0)/time_steps
        log_pzs = torch.sum(log_pzs,axis=0)/time_steps
        log_qzs = torch.sum(log_qzs,axis=0)/time_steps

        Qzs=torch.stack(Qzs)
        Qzs = Qzs.permute(1,2,0,3)
        return Loss,Qzs,Esample, log_xs, log_pzs, -log_qzs, log_likelihood,alphas
    
    def forward_GTF(self,x,u,alpha):
        """
        Deterministic (noise-free) forward pass with generalised teacher forcing.

        The latent targets z_hat are obtained by inverting the observation model on
        the data. At every step the RNN prediction is mixed with the target:
            Fz_t = (1-alpha)*F(Fz_{t-1}) + alpha*z_hat_t

        Args:
            x (torch.tensor; n_trials x dim_X x time_steps): input data
            u (torch.tensor; n_trials x dim_u x time_steps, or None): external inputs;
                None is replaced by zeros, as in predict_NLB
            alpha: teacher-forcing mixing coefficient
        Returns:
            Loss: negative of self.MSE_loss between x[:,:,1:] and the predicted observations
            Fzs: RNN predictions (before mixing) for steps 1..time_steps-1
            z_hat: latent targets from the inverted observation model
            four placeholder ones-tensors (keep the output aligned with the stochastic passes)
            alpha (torch.tensor; 1): alpha as a tensor on the loss device
        """
        # The targets come from inverting the observations; no gradient through the inversion
        with torch.no_grad():
            z_hat = self.prior.inv_observation(x).unsqueeze(-1)

        _, _, time_steps = x.shape
        if u is None:
            u = torch.zeros(x.shape[0],self.dim_u,time_steps).to(x.device)
        u = u.unsqueeze(-1)

        Fzs = []
        Fz = z_hat[:,:,0]
        for t in range(1,time_steps):
            Fz_out = self.prior(Fz.unsqueeze(-2),noise_scale=0,u=u[:,:,t].unsqueeze(-2)).squeeze(-2)
            # teacher forcing: pull the state towards the data-derived target
            Fz = (1-alpha)*Fz_out+alpha*z_hat[:,:,t]
            Fzs.append(Fz_out)
        Fzs = torch.stack(Fzs).permute(1,2,0,3)

        outputs = self.prior.get_observation(Fzs,noise_scale=0)
        Loss = -self.MSE_loss(x[:,:,1:].unsqueeze(-1), outputs)
        alpha = torch.ones(1,device = Loss.device)*alpha
        # separate tensors (not one shared object), as in the original return value
        placeholders = [torch.ones(1,device = Loss.device) for _ in range(4)]
        return (Loss,Fzs,z_hat,*placeholders,alpha)

        

    def to_device(self,device):
        """
        Move the encoder and the prior (cpu / gpu, e.g. "cuda") to `device`.

        Besides calling `.to` on both networks, the tensors held by their `normal`
        distributions and the prior's observation mask and input weights Wu are
        moved explicitly (presumably because they are not registered parameters or
        buffers, which is all that nn.Module.to moves).
        """
        for net in (self.encoder, self.prior):
            net.to(device=device)
            dist = net.normal
            dist.loc = dist.loc.to(device=device)
            dist.scale = dist.scale.to(device=device)

        observation = self.prior.observation
        transition = self.prior.transition
        observation.mask = observation.mask.to(device=device)
        transition.Wu = transition.Wu.to(device=device)
  

    def predict_NLB(self, x,u=None, k=1,t_held_in = 0, t_forward= 0,resample='systematic', marginal_smoothing=True):
        self.eval()
        ll_x_func = lambda x,mu: torch.distributions.Poisson(self.obs_rectify(mu)).log_prob(x).sum(axis=1)

        # Run the encoder
        Esample, Emean,log_Evar, eps_sample = self.encoder(x[:,:self.dim_x],k=k) #Bs,Dx,T,K
        x_hat = x.unsqueeze(-1)

        # Clamp the variances to avoid numerical issues
        Evar = torch.clamp(torch.exp(log_Evar),min=self.min_var,max= self.max_var)
        # Get the effective variances for the prior and decoder
        eff_var_prior = torch.clamp(self.prior.var_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1     
        eff_std_prior = torch.clamp(self.prior.std_embed_z(self.prior.R_z).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1       
        eff_var_prior_t0 = torch.clamp(self.prior.var_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=self.min_var,max= self.max_var)  #1,Dz,1       
        eff_std_prior_t0 = torch.clamp(self.prior.std_embed_z_t0(self.prior.R_z_t0).unsqueeze(0).unsqueeze(-1),min=np.sqrt(self.min_var),max= np.sqrt(self.max_var))  #1,Dz,1       

        # Initialise some lists 
        bs,dim_z,time_steps,_ = Esample.shape
        log_ws = []
        Qzs = []
        Qzs_filt = []

        # Get the prior mean and observation mean
        if u is None:
            u= torch.zeros(x.shape[0],self.dim_u,x.shape[2]).to(x.device)
        prior_mean = self.prior.get_initial_state(u[:,:,0]).unsqueeze(2)
        u = u.unsqueeze(-1)
        # get the posterior mean
        precZ = 1/eff_var_prior_t0
        precE = 1/Evar[:,:,0]
        precQ = precZ+precE
        alpha = precE/precQ

        mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,0]
        eff_var_Q=1/precQ

        Q_dist = torch.distributions.Normal(loc = mean_Q,scale=torch.sqrt(eff_var_Q))
        #Q_dist = torch.distributions.Normal(loc=Emean[:,:,0],scale = torch.sqrt(Evar[:,:,0]))
        Qz = Q_dist.rsample()
        mean_x= self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)

        ll_pz = torch.distributions.Normal(loc = prior_mean,scale = eff_std_prior_t0).log_prob(Qz).sum(axis=1)
        ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
        ll_x = ll_x_func(x_hat[:,:,0], mean_x[:,:x_hat.shape[1]])

        log_w = ll_x+ll_pz-ll_qz   
        
        # Append the log weights and log likelihoods
        log_ws.append(log_w)
        Qzs.append(Qz)
        
        # Loop through the time steps
        for t in range(1,t_held_in):

            if resample == "multinomial":
                indices = self.sample_indices_multinomial(log_w)
                Qz = self.resample(Qz, indices)
                Qzs_filt.append(Qz)
            elif resample == "systematic":
                indices = self.sample_indices_systematic(log_w)
                Qz = self.resample(Qz, indices)
                Qzs_filt.append(Qz)
            elif resample == "none":
                pass
            else:
                print("WARNING: resample does not exist")
                print("use, one of: multinomial, systematic, none")
    
            # prior and posterior mean
            prior_mean = self.prior(Qz.unsqueeze(-2),noise_scale=0,u=u[:,:,t].unsqueeze(2)).squeeze(-2)
            
            precZ = 1/eff_var_prior
            precE = 1/Evar[:,:,t]
            precQ = precZ+precE
            alpha = precE/precQ

            mean_Q = (1-alpha)*prior_mean+alpha*Emean[:,:,t]
            eff_var_Q=1/precQ

            Q_dist = torch.distributions.Normal(loc=mean_Q,scale=torch.sqrt(eff_var_Q))
            
            Qz = Q_dist.rsample()
            # Get the observation mean
            mean_x = self.prior.get_observation(Qz.unsqueeze(-2),noise_scale=0).squeeze(-2)

            # Calculate the log likelihoods
            ll_pz = torch.distributions.Normal(loc = prior_mean,scale=eff_std_prior).log_prob(Qz).sum(axis=1)
            ll_qz = Q_dist.log_prob(Qz).sum(axis=1)
            ll_x = ll_x_func(x_hat[:,:,t], mean_x[:,:x_hat.shape[1]])

            # Calculate the log weights
            log_w = ll_x+ll_pz-ll_qz   
            log_ws.append(log_w)
            Qzs.append(Qz)

        # resample last time steps
        if resample == "multinomial":
                indices = self.sample_indices_multinomial(log_w)
                Qz = self.resample(Qz, indices)
                Qzs_filt.append(Qz)
                
        elif resample == "systematic":
                indices = self.sample_indices_systematic(log_w)
                Qz = self.resample(Qz, indices)
                Qzs_filt.append(Qz)
        
        
        # Backward Smoothing
        Qzs_sm = torch.zeros_like(torch.stack(Qzs))
        Qzs_sm[-1]= Qzs_filt[-1]

        # Start from the end and move backwards
        log_weights_backward = log_ws[-1]
        for t in range(t_held_in- 2, -1, -1):

            if marginal_smoothing:
                # Marginal smoothing, note this jas K^2 cost!

                prior_mean =  self.prior(Qzs[t].unsqueeze(-2),noise_scale=0,u=u[:,:,t].unsqueeze(2)).squeeze(-2)
                probs_ij = (torch.distributions.Normal(loc = prior_mean.unsqueeze(3),
                                                                scale=eff_std_prior.unsqueeze(3)).log_prob(Qzs_sm[t+1].unsqueeze(2)).sum(axis=1))
                log_denom = torch.logsumexp(log_ws[t].unsqueeze(-1)*probs_ij[:,:],axis=1)
                
                log_weight  = log_ws[t]
                for i in range(k):
                    log_nom_i = probs_ij[:,i]
                    reweight_i =torch.logsumexp(log_weights_backward+log_nom_i-log_denom,axis = 1)
                    log_weight[:,i] += reweight_i
                indices = self.sample_indices_systematic(log_weight)
                Qz = self.resample(Qzs[t], indices)
                Qzs_sm[t] = Qz
            else:
                # Conditional smoothing
                prior_mean =  self.prior(Qzs[t].unsqueeze(-2),noise_scale=0,u=u[:,:,t].unsqueeze(2)).squeeze(-2)
                ll_pz = torch.distributions.Normal(loc = prior_mean,scale=eff_std_prior).log_prob(Qzs_sm[t+1]).sum(axis=1)
                log_weights_reweighted = log_ws[t] + ll_pz
                
                # Resample based on the backward weights
                indices = self.sample_indices_systematic(log_weights_reweighted)
                Qz = self.resample(Qzs[t], indices)
                Qzs_sm[t] = Qz

            
        # Use Bootstrap samples for the last n_forward steps
        for t in range(t_held_in,t_held_in+t_forward):
            Qz = self.prior(Qz.unsqueeze(-2),noise_scale=1).squeeze(-2)
            Qzm = torch.mean(Qz,axis=-1)
            # take mean as most likely... not neccerily the most likely, as trajectories could diverge
            mean_x = self.prior.get_observation(Qzm.unsqueeze(-1).unsqueeze(-1),noise_scale=0).squeeze(-1).squeeze(-1)
            Qzs.append(Qz)

        Qzs_filt=torch.stack(Qzs_filt).permute(1,2,0,3)
        Qzs_sm = Qzs_sm.permute(1,2,0,3)
        
        Xs_filt = self.obs_rectify(self.prior.get_observation(Qzs_filt,noise_scale=0))
        Xs_sm = self.obs_rectify(self.prior.get_observation(Qzs_sm,noise_scale=0))

        return Qzs_filt, Qzs_sm, Xs_filt, Xs_sm
