from functools import partial
from jax.scipy.optimize import minimize
import jax.numpy as np
from jax import vmap, lax, jit
from jax.tree_util import tree_map, register_pytree_node_class
from tensorflow_probability.substrates import jax as tfp
from jax.config import config
config.update("jax_enable_x64", True)

from ssm.hmm.emissions import Emissions
from ssm.hmm.posterior import StationaryHMMPosterior
import ssm.distributions as ssmd
# Short alias for the TensorFlow Probability (JAX substrate) distributions module.
tfd = tfp.distributions

# Diagonal jitter added to every estimated covariance matrix before it is
# Cholesky-factorized (see SALTEmissions.update_covariance_matrix_sqrts).
EPS = 1e-4

@register_pytree_node_class
class SALTEmissions(Emissions):
    r"""Switching Autoregressive Low-Rank Tensor (SALT) emissions.

    For each discrete state ``k`` the autoregressive tensor of shape
    ``(N, N, L)`` (output dim, input dim, lag) is stored in factored form as a
    core tensor ``G_k`` contracted with output factors ``U_k``, input factors
    ``V_k`` and lag factors ``W_k``.  The emission at time ``t`` is Gaussian
    with mean ``sum_{j,l} A_k[:, j, l] * history[l, j] + d_k`` and a
    Cholesky-parameterized covariance.

    Shape conventions used throughout (taken from the constructor arguments):
    ``K`` states, ``N`` emission dims, ``L`` lags, core dims ``(D1, D2, D3)``,
    ``B`` batches and ``T`` timesteps.
    """

    def __init__(self,
                 num_states: int,
                 mode: str,
                 l2_penalty: float,
                 output_factors: np.ndarray,            # U: (K, N, D1)
                 input_factors: np.ndarray,             # V: (K, N, D2)
                 lag_factors: np.ndarray,               # W: (K, L, D3)
                 core_tensors: np.ndarray,              # G: (K, D1, D2, D3)
                 biases: np.ndarray,                    # d: (K, N)
                 covariance_matrix_sqrts: np.ndarray,   # Sigma: (K, N, N)
                 ):
        r"""Store the factors and precompute the full autoregressive tensors.

        Args:
            num_states (int): number of discrete states ``K``.
            mode (str): when equal to ``'cp'`` the core tensors are left
                untouched by the M-step; any other value re-estimates them.
            l2_penalty (float): ridge coefficient added to the diagonal of
                every normal-equation matrix in the M-step.
            output_factors (np.ndarray): ``U`` of shape ``(K, N, D1)``.
            input_factors (np.ndarray): ``V`` of shape ``(K, N, D2)``.
            lag_factors (np.ndarray): ``W`` of shape ``(K, L, D3)``.
            core_tensors (np.ndarray): ``G`` of shape ``(K, D1, D2, D3)``.
            biases (np.ndarray): ``d`` of shape ``(K, N)``.
            covariance_matrix_sqrts (np.ndarray): lower-triangular scale
                factors of shape ``(K, N, N)`` (passed as ``scale_tril``).
        """
        super(SALTEmissions, self).__init__(num_states)
        self.mode = mode
        self.l2_penalty = l2_penalty

        self.output_factors = output_factors
        self.input_factors = input_factors
        self.lag_factors = lag_factors
        self.core_tensors = core_tensors
        self.biases = biases
        self.covariance_matrix_sqrts = covariance_matrix_sqrts

        # Full AR tensors A[k, i, j, l]: (K, N, N, L).
        self.tensors = np.einsum('kdef,kid,kje,klf->kijl',
                                 core_tensors,
                                 output_factors,
                                 input_factors,
                                 lag_factors)
        # Same contraction with the lag mode left open: (K, N, N, D3).
        # Used as the design tensor when solving for the lag factors.
        self.tensors_for_lag_factors = np.einsum('kabc,kia,kjb->kijc',
                                                 core_tensors,
                                                 output_factors,
                                                 input_factors)

    @property
    def emissions_dim(self):
        """Dimensionality ``N`` of a single observation."""
        return self.output_factors.shape[1]

    @property
    def emissions_shape(self):
        """Event shape of a single observation, ``(N,)``."""
        return (self.emissions_dim,)

    @property
    def num_lags(self):
        """Number of autoregressive lags ``L``."""
        return self.lag_factors.shape[1]

    @property
    def core_tensor_dims(self):
        """Per-state core tensor dimensions ``(D1, D2, D3)``."""
        return self.core_tensors.shape[1:]

    def distribution(self, state: int,
                     covariates=None,
                     metadata=None,
                     history: np.ndarray=None) -> tfd.MultivariateNormalTriL:
        """Returns the emissions distribution conditioned on a given state.

        Args:
            state (int): latent state
            covariates (np.ndarray, optional): optional covariates.
                Not yet supported. Defaults to None.
            metadata (optional): unused here. Defaults to None.
            history (np.ndarray): the ``(L, N)`` window of preceding
                observations. Required in practice even though the signature
                defaults it to None. Row ``l`` is paired with lag index ``l``
                of the tensor; ``log_likelihoods`` pairs lag index 0 with the
                oldest observation in the window, so the same ordering should
                be used here.

        Returns:
            emissions_distribution (tfd.MultivariateNormalTriL): the emissions distribution
        """
        # multiply history (L, N) with the tensor for given state to get shape (N,) prediction
        mean = np.einsum('ijl,lj->i',
                         self.tensors[state],
                         history)

        mean += self.biases[state]

        return tfd.MultivariateNormalTriL(mean, self.covariance_matrix_sqrts[state])

    def log_likelihoods(self, data, covariates=None, metadata=None):
        """Compute per-timestep, per-state log likelihoods of one sequence.

        Args:
            data (np.ndarray): observations of shape ``(T, N)``.
            covariates (optional): unused.
            metadata (optional): unused.

        Returns:
            log_probs (np.ndarray): shape ``(T, K)``. The first ``L`` rows are
                zeros because those timesteps lack a full history window.
        """
        num_lags = self.num_lags
        num_states = self.num_states
        num_timesteps, emissions_dim = data.shape

        # (K, N, N, L) -> (K, N, L, N) -> (K*N, 1, L, N): one single-channel
        # (L x N) filter per (state, output dimension) pair.
        tensors_reshaped = self.tensors.transpose(0, 1, 3, 2).reshape(
            (-1, num_lags, emissions_dim))[:, None]

        # The last observation is never used as a predictor, hence data[:-1].
        data_reshaped = data[:-1].reshape(1, 1, num_timesteps - 1, emissions_dim)

        mean = lax.conv(data_reshaped,     # 1, 1, T-1, N
                        tensors_reshaped,  # K*N, 1, L, N
                        window_strides=(1, 1),
                        padding='VALID')   # 1, K*N, T-L, 1
        mean = mean[0, :, :, 0].reshape((num_states,
                                         emissions_dim,
                                         num_timesteps - num_lags))  # K, N, T-L
        mean = mean.transpose([2, 0, 1])  # T-L, K, N
        mean += self.biases

        log_probs = tfd.MultivariateNormalTriL(
            mean, self.covariance_matrix_sqrts).log_prob(data[num_lags:, None, :])
        # Pad the first L timesteps with zeros so the output has T rows.
        log_probs = np.vstack([np.zeros((num_lags, num_states)), log_probs])
        return log_probs

    def update_core_tensors(self, dataset, Y, conv, Ez, Qinvs, mode):
        """Solve the weighted ridge regression for the core tensors.

        Args:
            dataset: unused; kept for a uniform update signature.
            Y (np.ndarray): bias-subtracted targets, ``(B, K, T-L, N)``.
            conv (np.ndarray): lag-convolved data, ``(B, K, D3, T-L, N)``.
            Ez (np.ndarray): expected states, ``(B, T-L, K)``.
            Qinvs (np.ndarray): precision matrices, ``(K, N, N)``.
            mode (str): ``'cp'`` returns the current core tensors unchanged.

        Returns:
            core_tensors (np.ndarray): shape ``(K, D1, D2, D3)``.
        """
        if mode == 'cp':
            return self.core_tensors

        num_states = self.num_states
        D1, D2, D3 = self.core_tensor_dims
        U = self.output_factors

        def _get_Xhat_for_core_tensors(Uk, Xk):
            # Kronecker product over the trailing axis: the flattened index is
            # d1 * (D2*D3) + (d2 * D3 + d3), matching the final reshape below.
            Xhat = np.kron(Uk[None, None], Xk[:, :, None])  # (B,T-L,N,D1*D2*D3)
            return Xhat

        X = np.einsum('kje,pkftj->kptef', self.input_factors, conv)  # (K, B, T-L, D2, D3)
        X = X.reshape(X.shape[:-2] + (-1,))  # (K, B, T-L, D2*D3)
        Xhat = vmap(_get_Xhat_for_core_tensors)(U, X)  # (K,B,T-L,N,D1*D2*D3)

        J = np.einsum('ptk,kptni,knm,kptmj->kij', Ez, Xhat, Qinvs, Xhat)  # (K,D1*D2*D3,D1*D2*D3)
        J += np.eye(J.shape[-1])[None] * self.l2_penalty

        h = np.einsum('ptk,kptni,knm,pktm->ki', Ez, Xhat, Qinvs, Y)  # (K,D1*D2*D3)
        # Pass the right-hand side as an explicit column so the batched solve
        # does not depend on how a (K, D) array is interpreted by `solve`.
        core_tensors = np.linalg.solve(J, h[..., None])[..., 0]  # (K, D1*D2*D3)
        core_tensors = core_tensors.reshape(num_states, D1, D2, D3)

        return core_tensors

    def update_output_factors_and_biases(self, dataset, Y, conv, Ez):
        """Jointly solve for the output factors and biases.

        Args:
            dataset (np.ndarray): observations, ``(B, T, N)``.
            Y: unused; kept for a uniform update signature.
            conv (np.ndarray): lag-convolved data, ``(B, K, D3, T-L, N)``.
            Ez (np.ndarray): expected states, ``(B, T-L, K)``.

        Returns:
            output_factors (np.ndarray): shape ``(K, N, D1)``.
            biases (np.ndarray): shape ``(K, N)``.
        """
        num_lags = self.num_lags
        xhat = np.einsum('kdef,kje,pkftj->pktd',
                         self.core_tensors,
                         self.input_factors,
                         conv)  # (B, K, T-L, D1)
        # Prepend a constant-one regressor so the bias is estimated together
        # with the factors; note the l2 penalty below also applies to it.
        xhat = np.pad(xhat, ((0, 0), (0, 0), (0, 0), (1, 0)), constant_values=1)  # (B, K, T-L, 1+D1)

        J = np.einsum('ptk,pkti,pktj->kji', Ez, xhat, xhat)  # (K, 1+D1, 1+D1)
        J += np.eye(J.shape[-1])[None] * self.l2_penalty

        h = np.einsum('ptk,ptn,pkti->kin', Ez, dataset[:, num_lags:], xhat)  # (K, 1+D1, N)
        output_factors_and_biases = np.linalg.solve(J, h)  # (K, 1+D1, N)
        output_factors_and_biases = np.transpose(output_factors_and_biases, [0, 2, 1])
        # Column 0 is the bias (the padded regressor); the rest are the factors.
        return output_factors_and_biases[:, :, 1:], output_factors_and_biases[:, :, 0]

    def update_input_factors(self, dataset, Y, conv, Ez, Qinvs):
        """Solve the weighted ridge regression for the input factors.

        Args:
            dataset: unused; kept for a uniform update signature.
            Y (np.ndarray): bias-subtracted targets, ``(B, K, T-L, N)``.
            conv (np.ndarray): lag-convolved data, ``(B, K, D3, T-L, N)``.
            Ez (np.ndarray): expected states, ``(B, T-L, K)``.
            Qinvs (np.ndarray): precision matrices, ``(K, N, N)``.

        Returns:
            input_factors (np.ndarray): shape ``(K, N, D2)``.
        """
        emissions_dim = self.emissions_dim
        num_states = self.num_states
        D2 = self.core_tensor_dims[1]

        X = np.einsum('kia,kabc,pkctj->pktijb',
                      self.output_factors,
                      self.core_tensors,
                      conv)  # (B, K, T-L, N, N, D2)
        X = X.reshape(X.shape[:-2] + (-1,))  # (B, K, T-L, N, N*D2)

        J = np.einsum('pktab,kac,ptk,pktcd->kbd', X, Qinvs, Ez, X)  # (K, N*D2, N*D2)
        J += np.eye(J.shape[-1])[None] * self.l2_penalty

        h = np.einsum('pktab,kac,ptk,pktc->kb', X, Qinvs, Ez, Y)  # (K, N*D2)
        # Explicit column right-hand side: unambiguous batched vector solve.
        input_factors = np.linalg.solve(J, h[..., None])[..., 0]  # (K, N*D2)
        input_factors = input_factors.reshape(num_states, emissions_dim, D2)

        return input_factors

    def update_lag_factors(self, dataset, Y, Ez, Qinvs):
        """Solve for the lag factors and return the resulting predictions.

        Relies on ``self.tensors_for_lag_factors`` reflecting the current
        output factors, input factors and core tensors.

        Args:
            dataset (np.ndarray): observations, ``(B, T, N)``.
            Y (np.ndarray): bias-subtracted targets, ``(B, K, T-L, N)``.
            Ez (np.ndarray): expected states, ``(B, T-L, K)``.
            Qinvs (np.ndarray): precision matrices, ``(K, N, N)``.

        Returns:
            lag_factors (np.ndarray): shape ``(K, L, D3)``.
            Yhat (np.ndarray): per-state predicted means including biases,
                shape ``(B, K, T-L, N)``.
        """
        num_batches, num_timesteps, emissions_dim = dataset.shape
        num_lags = self.num_lags
        num_states = self.num_states
        D3 = self.core_tensor_dims[2]

        def _get_Xhat_for_lag_factors(t):
            # Window of the L observations immediately preceding time t.
            history = lax.dynamic_slice(dataset,
                                        (0, t - num_lags, 0),
                                        (num_batches, num_lags, emissions_dim))

            Xhat = np.einsum('kijc,plj->pkilc',
                             self.tensors_for_lag_factors,
                             history)  # (B, K, N, L, D3)
            return Xhat.reshape((num_batches, num_states, emissions_dim, num_lags * D3))  # (B,K,N,L*D3)

        Xhat = vmap(_get_Xhat_for_lag_factors)(np.arange(num_lags, num_timesteps))  # (T-L, B, K, N, L*D3)
        J = np.einsum('ptk,tpknd,kno,tpkoe->kde', Ez, Xhat, Qinvs, Xhat)  # (K, L*D3, L*D3)
        J += np.eye(J.shape[-1])[None] * self.l2_penalty

        h = np.einsum('ptk,tpknd,kno,pkto->kd', Ez, Xhat, Qinvs, Y)  # (K, L*D3)
        # Explicit column right-hand side: unambiguous batched vector solve.
        lag_factors = np.linalg.solve(J, h[..., None])[..., 0]  # (K, L*D3)

        # Predictions under the freshly solved lag factors; computed before
        # the reshape because Xhat uses the flattened (L*D3) layout.
        Yhat = np.einsum('tpkij,kj->pkti',
                         Xhat, lag_factors)  # (B, K, T-L, N)
        Yhat += self.biases[None, :, None]
        lag_factors = lag_factors.reshape(num_states, num_lags, D3)  # (K, L, D3)

        return lag_factors, Yhat

    def update_covariance_matrix_sqrts(self, dataset, Yhat, Ez):
        """Re-estimate the Cholesky factors of the per-state covariances.

        Args:
            dataset (np.ndarray): observations, ``(B, T, N)``.
            Yhat (np.ndarray): per-state predicted means, ``(B, K, T-L, N)``.
            Ez (np.ndarray): expected states, ``(B, T-L, K)``.

        Returns:
            covariance_matrix_sqrts (np.ndarray): lower-triangular factors of
                shape ``(K, N, N)``.
        """
        emissions_dim = dataset.shape[-1]
        num_lags = self.num_lags
        num_states = self.num_states

        Y = dataset[:, num_lags:].reshape(-1, emissions_dim)  # (B*(T-L), N)
        Yhat_reshaped = np.transpose(Yhat, [1, 0, 2, 3]).reshape(num_states, -1, emissions_dim)
        Ez_reshaped = Ez.reshape(-1, num_states).T  # (K, B*(T-L))

        # NOTE: np.cov centres the residuals on their (weighted) mean before
        # forming the weighted second moment.
        covariance_matrices = vmap(lambda yhatk, Ezk: np.cov(Y - yhatk,
                                                             rowvar=False,
                                                             bias=True,
                                                             aweights=Ezk))(Yhat_reshaped, Ez_reshaped)

        # Jitter the diagonal so the Cholesky factorization is well defined.
        covariance_matrices += np.eye(emissions_dim)[None] * EPS
        covariance_matrix_sqrts = np.linalg.cholesky(covariance_matrices)

        return covariance_matrix_sqrts

    def convolve_dataset_with_lag_factors(self, dataset):
        """Project each history window onto the lag factors.

        Args:
            dataset (np.ndarray): observations, ``(B, T, N)``.

        Returns:
            conv_output (np.ndarray): shape ``(B, K, D3, T-L, N)``; entry
                ``[b, k, c, s, j]`` is ``sum_l W[k, l, c] * dataset[b, s + l, j]``.
        """
        num_batches, num_timesteps, emissions_dim = dataset.shape
        num_lags = self.num_lags
        num_states = self.num_states
        D3 = self.core_tensor_dims[2]

        lag_factors = np.transpose(self.lag_factors, [0, 2, 1])  # K, L, D3 -> K, D3, L
        # One (L x 1) single-channel filter per (state, lag-rank) pair.
        lag_factors = lag_factors.reshape((num_states * D3, 1, num_lags, 1))  # K, D3, L -> K*D3, 1, L, 1

        conv_output = lax.conv(dataset[:, None, :-1],  # B, 1, T-1, N
                               lag_factors,            # K*D3, 1, L, 1
                               window_strides=(1, 1),
                               padding='VALID')        # B, K*D3, T-L, N
        conv_output = conv_output.reshape((num_batches,
                                           num_states,
                                           D3,
                                           num_timesteps - num_lags,
                                           emissions_dim))  # B, K, D3, T-L, N
        return conv_output

    def m_step(self,
               dataset: np.ndarray,
               posterior: StationaryHMMPosterior,
               covariates=None,
               metadata=None):
        r"""Update the distribution with an M step.

        Operates over a batch of data. Parameters are updated block by block
        (output factors and biases, core tensors, input factors, lag factors,
        covariances), each conditioned on the most recent values of the
        others. The precision matrices used in the weighted regressions are
        computed once, from the covariances held on entry.

        This method mutates ``self`` in place and also returns it.

        Args:
            dataset (np.ndarray): observed data
                of shape :math:`(\text{batch\_dim}, \text{num\_timesteps}, \text{emissions\_dim})`.
            posterior (StationaryHMMPosterior): HMM posterior object
                with batch_dim to match dataset.
            covariates (optional): unused.
            metadata (optional): unused.

        Returns:
            emissions (SALTEmissions): updated emissions object
        """
        num_lags = self.num_lags
        mode = self.mode

        # Drop the first L timesteps, which have no complete history window.
        Ez = posterior.expected_states[:, num_lags:]  # B, T-L, K
        Qs = np.einsum('kab,kcb->kac',
                       self.covariance_matrix_sqrts,
                       self.covariance_matrix_sqrts)
        Qinvs = np.linalg.inv(Qs)  # K, N, N
        # Uses the lag factors from before this M-step.
        conv = self.convolve_dataset_with_lag_factors(dataset)  # B, K, D3, T-L, N

        # update output factors and biases
        self.output_factors, self.biases = self.update_output_factors_and_biases(dataset, None, conv, Ez)

        Y = dataset[:, None, num_lags:] - self.biases[None, :, None]  # B, K, T-L, N

        # update core tensors
        self.core_tensors = self.update_core_tensors(dataset, Y, conv, Ez, Qinvs, mode)

        # update input factors
        self.input_factors = self.update_input_factors(dataset, Y, conv, Ez, Qinvs)

        # Refresh the lag-free tensor so the lag update sees the new factors.
        self.tensors_for_lag_factors = np.einsum('kabc,kia,kjb->kijc',
                                                 self.core_tensors,
                                                 self.output_factors,
                                                 self.input_factors)

        # update lag factors
        self.lag_factors, Yhat = self.update_lag_factors(dataset, Y, Ez, Qinvs)

        # update covariance_matrix_sqrts
        self.covariance_matrix_sqrts = self.update_covariance_matrix_sqrts(dataset, Yhat, Ez)

        # Rebuild the full AR tensors from the updated factors.
        self.tensors = np.einsum('kdef,kid,kje,klf->kijl',
                                 self.core_tensors,
                                 self.output_factors,
                                 self.input_factors,
                                 self.lag_factors)

        # return updated self
        return self

    def tree_flatten(self):
        """Split into array leaves and static auxiliary data for JAX pytrees.

        The order of both tuples mirrors the constructor signature so that
        ``tree_unflatten`` can splat them straight back into ``__init__``.
        """
        children = (self.output_factors,
                    self.input_factors,
                    self.lag_factors,
                    self.core_tensors,
                    self.biases,
                    self.covariance_matrix_sqrts,
                    )
        aux_data = (self.num_states,
                    self.mode,
                    self.l2_penalty,
                    )
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``tree_flatten`` output.

        Goes through ``__init__``, so the cached tensors are recomputed.
        """
        return cls(*aux_data, *children)
