"""Prior Network"""
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian
from .mlp import NLayerLeakyMLP, NLayerLeakyNAC, GaussianNet
from .base import GroupLinearLayer
import ipdb as pdb


class MBDTransitionPrior(nn.Module):
    """Linear transition prior evaluated on sliding windows of latents.

    Every window holds ``lags`` past latent vectors followed by the current
    one.  The past vectors go through ``GroupLinearLayer`` (defined in
    ``.base``; presumably one linear map per lag since ``num_blocks=lags`` --
    confirm there), the result is summed over the lag axis, and the current
    latent is subtracted to obtain the residual.

    Args:
        lags: Number of past time steps ``L`` in each window.
        latent_size: Latent dimensionality ``D``.
        bias: If true, a learnable ``[1, D]`` offset is added to the
            prediction before the current latent is subtracted.
    """

    def __init__(self, lags, latent_size, bias=False):
        super().__init__()
        self.L = lags
        self.transition = GroupLinearLayer(din=latent_size,
                                           dout=latent_size,
                                           num_blocks=lags,
                                           diagonal=False)
        self.bias = bias
        if bias:
            self.b = nn.Parameter(0.001 * torch.randn(1, latent_size))

    def forward(self, x, mask=None):
        """Compute transition residuals for every length-(L+1) window of ``x``.

        Args:
            x: Latent sequence of shape ``[BS, T, D]``.
            mask: Unused; kept so the call signature stays unchanged.

        Returns:
            Tuple ``(residuals, log_abs_det_jacobian)`` where ``residuals``
            has shape ``[BS, T-L, D]`` and ``log_abs_det_jacobian`` is a
            zero tensor of shape ``[BS]`` on the same device as ``x``.
        """
        batch_size, _, input_dim = x.shape
        # unfold appends the window axis last: [BS, T-L, D, L+1]; move it in
        # front of the feature axis and merge batch and time -> [N, L+1, D].
        windows = x.unfold(dimension=1, size=self.L + 1, step=1)
        windows = torch.swapaxes(windows, 2, 3).reshape(-1, self.L + 1, input_dim)

        history = windows[:, :-1]  # [N, L, D]
        # Index the window axis explicitly.  A bare ``squeeze()`` would also
        # drop the feature axis when D == 1, and the subtraction below would
        # then broadcast [N, 1] - [N] into an [N, N] matrix.
        current = windows[:, -1]   # [N, D]

        predicted = torch.sum(self.transition(history), dim=1)
        if self.bias:
            predicted = predicted + self.b
        residuals = (predicted - current).reshape(batch_size, -1, input_dim)

        # Dummy log-determinant (0): the mapping is treated as an identity
        # for the change-of-variables term.
        log_abs_det_jacobian = torch.zeros(batch_size, device=x.device)
        return residuals, log_abs_det_jacobian

class NPTransitionPrior(nn.Module):
    """Nonparametric transition prior with one scalar network per latent dim.

    For latent dimension ``i`` the network ``gs[i]`` receives the ``L``
    lagged latent vectors, the ``L`` lagged actions and the current value of
    dimension ``i`` (last input column), and emits a one-dimensional
    residual.  The log-abs-determinant of the transformation is accumulated
    from the derivative of each residual w.r.t. that last column.

    Args:
        lags: Number of past time steps ``L`` in each window.
        latent_size: Latent dimensionality ``D``.
        action_size: Action dimensionality.
        num_layers: Forwarded to ``NLayerLeakyMLP``.
        hidden_dim: Forwarded to ``NLayerLeakyMLP``.
    """

    def __init__(
        self,
        lags,
        latent_size,
        action_size,
        num_layers=3,
        hidden_dim=64,
    ):
        super().__init__()
        self.L = lags
        # L lagged (latent, action) vectors plus one slot for the current value.
        n_inputs = lags * (latent_size + action_size) + 1
        per_dim_nets = []
        for _ in range(latent_size):
            per_dim_nets.append(
                NLayerLeakyMLP(in_features=n_inputs,
                               out_features=1,
                               num_layers=num_layers,
                               hidden_dim=hidden_dim)
            )
        self.gs = nn.ModuleList(per_dim_nets)

    def forward(self, x, action, masks=None):
        """Return per-dimension residuals and the summed log|det J|.

        Args:
            x: Latent sequence of shape ``[BS, T, D]``.
            action: Action sequence, unpacked as a 3-D tensor and windowed
                along dim 1 exactly like ``x``.
            masks: Optional indexable; ``masks[i]`` is multiplied into the
                flattened lagged latents before they are fed to ``gs[i]``.

        Returns:
            Tuple ``(residuals, sum_log_abs_det_jacobian)`` with shapes
            ``[BS, T-L, D]`` and ``[BS]``.
        """
        batch_size, length, input_dim = x.shape
        _, _, action_dim = action.shape
        n_lags = self.L
        window = n_lags + 1

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D] -> [N, L+1, D]
        latent_windows = torch.swapaxes(
            x.unfold(dimension=1, size=window, step=1), 2, 3
        ).reshape(-1, window, input_dim)
        current = latent_windows[:, -1:]  # [N, 1, D]
        history = latent_windows[:, :-1].reshape(-1, n_lags * input_dim)

        # Same windowing for the actions; only the lagged part is kept.
        action_windows = torch.swapaxes(
            action.unfold(dimension=1, size=window, step=1), 2, 3
        ).reshape(-1, window, action_dim)
        past_actions = action_windows[:, :-1].reshape(-1, n_lags * action_dim)

        residual_cols = []
        log_det_terms = []
        for dim_idx in range(input_dim):
            net = self.gs[dim_idx]
            lagged = history if masks is None else history * masks[dim_idx]
            net_in = torch.cat((lagged, past_actions, current[:, :, dim_idx]), dim=-1)
            residual_cols.append(net(net_in))
            with torch.enable_grad():
                jac = jacobian(net, net_in, create_graph=True, vectorize=True)
            # Determinant of a lower-triangular matrix is the product of its
            # diagonal entries: keep d out[n] / d in[n, -1] for each row n.
            log_det_terms.append(jac[:, 0, :, -1].diagonal().abs().log())

        residuals = torch.cat(residual_cols, dim=-1).reshape(batch_size, -1, input_dim)
        total_log_det = sum(log_det_terms)
        sum_log_abs_det_jacobian = total_log_det.reshape(
            batch_size, length - n_lags
        ).sum(dim=1)
        return residuals, sum_log_abs_det_jacobian

class NPDTransitionPrior(nn.Module):
    """Nonparametric transition prior over a latent split into four groups.

    ``latent_size_list`` is ``[s1, s2, s3, s4]``: the sizes of groups
    z1..z4, laid out contiguously along the feature axis.  Every latent
    dimension owns a scalar-output network whose inputs depend on its group:

    * z1: lagged (z1, z2), lagged actions, current value
    * z2: lagged (z1, z2), current value
    * z3: all lagged latents, lagged actions, current value
    * z4: all lagged latents, current value

    Args:
        lags: Number of past time steps ``L`` in each window.
        latent_size_list: Group sizes ``[s1, s2, s3, s4]``.
        action_size: Action dimensionality.
        num_layers: Forwarded to ``NLayerLeakyMLP``.
        hidden_dim: Forwarded to ``NLayerLeakyMLP``.
    """

    def __init__(
        self,
        lags,
        latent_size_list,
        action_size,
        num_layers=3,
        hidden_dim=64,
    ):
        super().__init__()

        self.L = lags
        self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim = (
            latent_size_list[0],
            latent_size_list[1],
            latent_size_list[2],
            latent_size_list[3],
        )
        self.z_dim = sum(latent_size_list)
        z12_dim = self.z1_dim + self.z2_dim

        def build(in_features, count):
            # ``count`` independent scalar-output networks of the same width.
            return [NLayerLeakyMLP(in_features=in_features,
                                   out_features=1,
                                   num_layers=num_layers,
                                   hidden_dim=hidden_dim) for _ in range(count)]

        # The trailing "+ 1" is the slot for the current value of the dimension.
        gs = (
            build(lags * (z12_dim + action_size) + 1, self.z1_dim)
            + build(lags * z12_dim + 1, self.z2_dim)
            + build(lags * (self.z_dim + action_size) + 1, self.z3_dim)
            + build(lags * self.z_dim + 1, self.z4_dim)
        )
        self.gs = nn.ModuleList(gs)

    def forward(self, x, action, masks=None):
        """Return per-dimension residuals and the summed log|det J|.

        Args:
            x: Latent sequence of shape ``[BS, T, D]``.
            action: Action sequence, unpacked as a 3-D tensor and windowed
                along dim 1 exactly like ``x``.
            masks: Not applied by this class.  The previous implementation
                read ``masks[i]`` and discarded it, so the inputs were
                identical with or without masks; the argument is kept only
                for call compatibility.

        Returns:
            Tuple ``(residuals, sum_log_abs_det_jacobian)`` with shapes
            ``[BS, T-L, D]`` and ``[BS]``.
        """
        batch_size, length, input_dim = x.shape
        z12_dim = self.z1_dim + self.z2_dim
        z123_dim = z12_dim + self.z3_dim

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D] -> [N, L+1, D]
        x = x.unfold(dimension=1, size=self.L + 1, step=1)
        x = torch.swapaxes(x, 2, 3)
        x = x.reshape(-1, self.L + 1, input_dim)
        # xx: current step [N, 1, D]; yy: lagged steps [N, L, D]
        xx, yy = x[:, -1:], x[:, :-1]
        # split also checks that the group sizes add up to D.
        z12_l, _ = torch.split(yy, [z12_dim, self.z3_dim + self.z4_dim], dim=-1)
        z12l = z12_l.reshape(-1, self.L * z12_dim)
        zl = yy.reshape(-1, self.L * input_dim)

        _, _, action_dim = action.shape
        a = action.unfold(dimension=1, size=self.L + 1, step=1)
        a = torch.swapaxes(a, 2, 3)
        a = a.reshape(-1, self.L + 1, action_dim)[:, :-1]
        a = a.reshape(-1, self.L * action_dim)

        residuals = []
        log_det_terms = []
        for i in range(input_dim):
            # Pick the conditioning set according to the group of dimension i.
            if i < self.z1_dim:
                context = (z12l, a)
            elif i < z12_dim:
                context = (z12l,)
            elif i < z123_dim:
                context = (zl, a)
            else:
                context = (zl,)
            inputs = torch.cat(context + (xx[:, :, i],), dim=-1)

            residuals.append(self.gs[i](inputs))
            with torch.enable_grad():
                pdd = jacobian(self.gs[i], inputs, create_graph=True, vectorize=True)
            # Determinant of a lower-triangular matrix is the product of its
            # diagonal entries: keep d out[n] / d in[n, -1] for each row n.
            log_det_terms.append(torch.log(torch.abs(torch.diag(pdd[:, 0, :, -1]))))

        residuals = torch.cat(residuals, dim=-1)
        residuals = residuals.reshape(batch_size, -1, input_dim)
        sum_log_abs_det_jacobian = sum(log_det_terms)
        sum_log_abs_det_jacobian = torch.sum(
            sum_log_abs_det_jacobian.reshape(batch_size, length - self.L), dim=1
        )
        return residuals, sum_log_abs_det_jacobian
    
class DenoisedTransitionPrior(nn.Module):
    """Nonparametric transition prior where each latent group evolves alone.

    ``latent_size_list`` is ``[s1, s2, s3, s4]``: the sizes of groups
    z1..z4, laid out contiguously along the feature axis.  Every latent
    dimension owns a scalar-output network that only sees the lagged values
    of its own group:

    * z1: lagged z1, lagged actions, current value
    * z2: lagged z2, current value
    * z3: lagged z3, lagged actions, current value
    * z4: lagged z4, current value

    Args:
        lags: Number of past time steps ``L`` in each window.
        latent_size_list: Group sizes ``[s1, s2, s3, s4]``.
        action_size: Action dimensionality.
        num_layers: Forwarded to ``NLayerLeakyMLP``.
        hidden_dim: Forwarded to ``NLayerLeakyMLP``.
    """

    def __init__(
        self,
        lags,
        latent_size_list,
        action_size,
        num_layers=3,
        hidden_dim=64,
    ):
        super().__init__()

        self.L = lags
        self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim = (
            latent_size_list[0],
            latent_size_list[1],
            latent_size_list[2],
            latent_size_list[3],
        )
        self.z_dim = sum(latent_size_list)

        def build(in_features, count):
            # ``count`` independent scalar-output networks of the same width.
            return [NLayerLeakyMLP(in_features=in_features,
                                   out_features=1,
                                   num_layers=num_layers,
                                   hidden_dim=hidden_dim) for _ in range(count)]

        # The trailing "+ 1" is the slot for the current value of the dimension.
        gs = (
            build(lags * (self.z1_dim + action_size) + 1, self.z1_dim)
            + build(lags * self.z2_dim + 1, self.z2_dim)
            + build(lags * (self.z3_dim + action_size) + 1, self.z3_dim)
            + build(lags * self.z4_dim + 1, self.z4_dim)
        )
        self.gs = nn.ModuleList(gs)

    def forward(self, x, action, masks=None):
        """Return per-dimension residuals and the summed log|det J|.

        Args:
            x: Latent sequence of shape ``[BS, T, D]``.
            action: Action sequence, unpacked as a 3-D tensor and windowed
                along dim 1 exactly like ``x``.
            masks: Not applied by this class.  The previous implementation
                read ``masks[i]`` and discarded it, so the inputs were
                identical with or without masks; the argument is kept only
                for call compatibility.

        Returns:
            Tuple ``(residuals, sum_log_abs_det_jacobian)`` with shapes
            ``[BS, T-L, D]`` and ``[BS]``.
        """
        # TODO (carried over from the original code): need change.
        batch_size, length, input_dim = x.shape
        z12_dim = self.z1_dim + self.z2_dim
        z123_dim = z12_dim + self.z3_dim

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D] -> [N, L+1, D]
        x = x.unfold(dimension=1, size=self.L + 1, step=1)
        x = torch.swapaxes(x, 2, 3)
        x = x.reshape(-1, self.L + 1, input_dim)
        # xx: current step [N, 1, D]; yy: lagged steps [N, L, D]
        xx, yy = x[:, -1:], x[:, :-1]
        # split also checks that the group sizes add up to D.
        group_sizes = [self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim]
        z1_l, z2_l, z3_l, z4_l = [
            chunk.reshape(-1, self.L * size)
            for chunk, size in zip(torch.split(yy, group_sizes, dim=-1), group_sizes)
        ]

        _, _, action_dim = action.shape
        a = action.unfold(dimension=1, size=self.L + 1, step=1)
        a = torch.swapaxes(a, 2, 3)
        a = a.reshape(-1, self.L + 1, action_dim)[:, :-1]
        a = a.reshape(-1, self.L * action_dim)

        residuals = []
        log_det_terms = []
        for i in range(input_dim):
            # Pick the conditioning set according to the group of dimension i.
            if i < self.z1_dim:
                context = (z1_l, a)
            elif i < z12_dim:
                context = (z2_l,)
            elif i < z123_dim:
                context = (z3_l, a)
            else:
                context = (z4_l,)
            inputs = torch.cat(context + (xx[:, :, i],), dim=-1)

            residuals.append(self.gs[i](inputs))
            with torch.enable_grad():
                pdd = jacobian(self.gs[i], inputs, create_graph=True, vectorize=True)
            # Determinant of a lower-triangular matrix is the product of its
            # diagonal entries: keep d out[n] / d in[n, -1] for each row n.
            log_det_terms.append(torch.log(torch.abs(torch.diag(pdd[:, 0, :, -1]))))

        residuals = torch.cat(residuals, dim=-1)
        residuals = residuals.reshape(batch_size, -1, input_dim)
        sum_log_abs_det_jacobian = sum(log_det_terms)
        sum_log_abs_det_jacobian = torch.sum(
            sum_log_abs_det_jacobian.reshape(batch_size, length - self.L), dim=1
        )
        return residuals, sum_log_abs_det_jacobian
    
class GaussianMLPTransitionPrior(nn.Module):
    """Gaussian transition prior over a latent split into four groups.

    ``latent_size_list`` is ``[s1, s2, s3, s4]``: the sizes of groups
    z1..z4, laid out contiguously along the feature axis.  One
    ``GaussianNet`` per group maps lagged inputs to a ``(mu, logvar)`` pair:

    * z1: lagged (z1, z2) and lagged actions
    * z2: lagged (z1, z2)
    * z3: all lagged latents and lagged actions
    * z4: all lagged latents
    """

    def __init__(
        self,
        lags,
        latent_size_list,
        action_size,
        num_layers=3,
        hidden_dim=64,
    ):
        super().__init__()

        self.L = lags
        self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim = (
            latent_size_list[0],
            latent_size_list[1],
            latent_size_list[2],
            latent_size_list[3],
        )
        self.z_dim = sum(latent_size_list)

        z12_dim = self.z1_dim + self.z2_dim
        # (input width, output width) of the network for z1, z2, z3, z4.
        specs = (
            (lags * (z12_dim + action_size), self.z1_dim),
            (lags * z12_dim, self.z2_dim),
            (lags * (self.z_dim + action_size), self.z3_dim),
            (lags * self.z_dim, self.z4_dim),
        )
        group_nets = [
            GaussianNet(input_size=n_in,
                        output_size=n_out,
                        num_layers=num_layers,
                        hidden_dim=hidden_dim)
            for n_in, n_out in specs
        ]
        self.gs = nn.ModuleList(group_nets)

    def forward(self, x, action, masks=None):
        """Build the Normal distribution predicted from each lagged window.

        Args:
            x: Latent sequence of shape ``[BS, T, D]``.
            action: Action sequence, unpacked as a 3-D tensor and windowed
                along dim 1 exactly like ``x``.
            masks: Unused.

        Returns:
            ``torch.distributions.Normal`` whose mean and scale concatenate
            the four group outputs along the last axis, one row per window
            (``BS * (T-L)`` rows); the scale is ``exp(0.5 * logvar)``.
        """
        # TODO need change.
        _, _, input_dim = x.shape
        n_lags = self.L
        window = n_lags + 1
        z12_dim = self.z1_dim + self.z2_dim

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D] -> [N, L+1, D]
        latent_windows = torch.swapaxes(
            x.unfold(dimension=1, size=window, step=1), 2, 3
        ).reshape(-1, window, input_dim)
        history = latent_windows[:, :-1]  # [N, L, D]
        # split also checks that the group sizes add up to D.
        history_z12, _ = torch.split(
            history, [z12_dim, self.z3_dim + self.z4_dim], dim=-1
        )
        flat_z12 = history_z12.reshape(-1, n_lags * z12_dim)
        flat_all = history.reshape(-1, n_lags * input_dim)

        _, _, action_dim = action.shape
        action_windows = torch.swapaxes(
            action.unfold(dimension=1, size=window, step=1), 2, 3
        ).reshape(-1, window, action_dim)
        past_actions = action_windows[:, :-1].reshape(-1, n_lags * action_dim)

        # Inputs in group order z1, z2, z3, z4.
        net_inputs = (
            torch.cat((flat_z12, past_actions), dim=-1),
            flat_z12,
            torch.cat((flat_all, past_actions), dim=-1),
            flat_all,
        )
        means, logvars = [], []
        for idx, net_in in enumerate(net_inputs):
            group_mu, group_logvar = self.gs[idx](net_in)
            means.append(group_mu)
            logvars.append(group_logvar)

        mu = torch.cat(means, dim=-1)
        logvar = torch.cat(logvars, dim=-1)
        return torch.distributions.Normal(mu, (0.5 * logvar).exp())

class NPChangeTransitionPrior(nn.Module):
    """Nonparametric transition prior conditioned on a per-sequence embedding.

    The embedding is first projected to ``hidden_dim`` features by ``fc``.
    For latent dimension ``i`` the network ``gs[i]`` then receives the
    projected embedding, the ``L`` lagged latent vectors and the current
    value of dimension ``i`` (last input column), and emits a
    one-dimensional residual.

    Args:
        lags: Number of past time steps ``L`` in each window.
        latent_size: Latent dimensionality ``D``.
        embedding_dim: Width of the raw embedding fed to ``fc``.
        num_layers: Forwarded to the per-dimension ``NLayerLeakyMLP``.
        hidden_dim: Hidden width of all networks and width of the
            projected embedding.
    """

    def __init__(
        self,
        lags,
        latent_size,
        embedding_dim,
        num_layers=3,
        hidden_dim=64,
    ):
        super().__init__()
        self.L = lags
        # Projected embedding + L lagged latents + one slot for the current value.
        gs = [NLayerLeakyMLP(in_features=hidden_dim + lags * latent_size + 1,
                             out_features=1,
                             num_layers=num_layers,
                             hidden_dim=hidden_dim) for _ in range(latent_size)]

        self.gs = nn.ModuleList(gs)
        self.fc = NLayerLeakyMLP(in_features=embedding_dim,
                                 out_features=hidden_dim,
                                 num_layers=2,
                                 hidden_dim=hidden_dim)

    def forward(self, x, embeddings, masks=None):
        """Return per-dimension residuals and the summed log|det J|.

        Args:
            x: Latent sequence of shape ``[BS, T, D]``.
            embeddings: Either one embedding per sequence, ``[BS, E]``, or
                one per window, ``[BS * (T-L), E]`` (rows ordered
                sequence-major, matching the flattened windows).
            masks: Optional indexable; ``masks[i]`` is multiplied into the
                flattened lagged latents before they are fed to ``gs[i]``.

        Returns:
            Tuple ``(residuals, sum_log_abs_det_jacobian)`` with shapes
            ``[BS, T-L, D]`` and ``[BS]``.
        """
        batch_size, length, input_dim = x.shape
        n_windows = length - self.L
        embeddings = self.fc(embeddings)
        # The windows below are flattened to BS * (T-L) rows.  A per-sequence
        # embedding therefore has to be repeated once per window, otherwise
        # the concatenation fails whenever a sequence yields more than one
        # window.  Embeddings that already have one row per window are left
        # untouched.
        if n_windows > 1 and embeddings.shape[0] == batch_size:
            embeddings = embeddings.repeat_interleave(n_windows, dim=0)

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D] -> [N, L+1, D]
        x = x.unfold(dimension=1, size=self.L + 1, step=1)
        x = torch.swapaxes(x, 2, 3)
        x = x.reshape(-1, self.L + 1, input_dim)
        # xx: current step [N, 1, D]; yy: flattened lagged steps [N, L*D]
        xx, yy = x[:, -1:], x[:, :-1]
        yy = yy.reshape(-1, self.L * input_dim)

        residuals = []
        log_det_terms = []
        for i in range(input_dim):
            lagged = yy if masks is None else yy * masks[i]
            inputs = torch.cat((embeddings, lagged, xx[:, :, i]), dim=-1)
            residuals.append(self.gs[i](inputs))
            with torch.enable_grad():
                pdd = jacobian(self.gs[i], inputs, create_graph=True, vectorize=True)
            # Determinant of a lower-triangular matrix is the product of its
            # diagonal entries: keep d out[n] / d in[n, -1] for each row n.
            log_det_terms.append(torch.log(torch.abs(torch.diag(pdd[:, 0, :, -1]))))

        residuals = torch.cat(residuals, dim=-1)
        residuals = residuals.reshape(batch_size, -1, input_dim)
        sum_log_abs_det_jacobian = sum(log_det_terms)
        sum_log_abs_det_jacobian = torch.sum(
            sum_log_abs_det_jacobian.reshape(batch_size, n_windows), dim=1
        )
        return residuals, sum_log_abs_det_jacobian
