{"title": "Combining Generative and Discriminative Models for Hybrid Inference", "book": "Advances in Neural Information Processing Systems", "page_first": 13825, "page_last": 13835, "abstract": "A graphical model is a structured representation of the data generating process. The traditional method to reason over random variables is to perform inference in this graphical model. However, in many cases the generating process is only a poor approximation of the much more complex true data generating process, leading to suboptimal estimation. The subtleties of the generative process are however captured in the data itself and we can ``learn to infer'', that is, learn a direct mapping from observations to explanatory latent variables. In this work we propose a hybrid model that combines graphical inference with a learned inverse model, which we structure as in a graph neural network, while the iterative algorithm as a whole is formulated as a recurrent neural network. By using cross-validation we can automatically balance the amount of work performed by graphical inference versus learned inference. 
We apply our ideas to the Kalman filter, a Gaussian hidden Markov model for time sequences, and show, among other things, that our model can estimate the trajectory of a noisy chaotic Lorenz Attractor much more accurately than either the learned or graphical inference run in isolation.", "full_text": "Combining Generative and Discriminative Models for Hybrid Inference

Victor Garcia Satorras
UvA-Bosch Delta Lab, University of Amsterdam, Netherlands
v.garciasatorras@uva.nl

Zeynep Akata*
Cluster of Excellence ML, University of Tübingen, Germany
zeynep.akata@uni-tuebingen.de

Max Welling
UvA-Bosch Delta Lab, University of Amsterdam, Netherlands
m.welling@uva.nl

Abstract

A graphical model is a structured representation of the data generating process. The traditional method to reason over random variables is to perform inference in this graphical model. However, in many cases the generating process is only a poor approximation of the much more complex true data generating process, leading to suboptimal estimation. The subtleties of the generative process are however captured in the data itself and we can "learn to infer", that is, learn a direct mapping from observations to explanatory latent variables. In this work we propose a hybrid model that combines graphical inference with a learned inverse model, which we structure as in a graph neural network, while the iterative algorithm as a whole is formulated as a recurrent neural network. By using cross-validation we can automatically balance the amount of work performed by graphical inference versus learned inference.
We apply our ideas to the Kalman filter, a Gaussian hidden Markov model for time sequences, and show, among other things, that our model can estimate the trajectory of a noisy chaotic Lorenz Attractor much more accurately than either the learned or graphical inference run in isolation.

1 Introduction

Before deep learning, one of the dominant paradigms in machine learning was graphical models [4, 27, 21]. Graphical models structure the space of (random) variables by organizing them into a dependency graph. For instance, some variables are parents/children (directed models) or neighbors (undirected models) of other variables. These dependencies are encoded by conditional probabilities (directed models) or potentials (undirected models). While these interactions can have learnable parameters, the structure of the graph imposes a strong inductive bias onto the model. Reasoning in graphical models is performed by a process called probabilistic inference where the posterior distribution, or the most probable state of a set of variables, is computed given observations of other variables. Many approximate algorithms have been proposed to solve this problem efficiently, among which are MCMC sampling [29, 33], variational inference [18] and belief propagation algorithms [10, 21].

Graphical models are a kind of generative model where we specify important aspects of the generative process. They excel in the low data regime because we maximally utilize expert knowledge (a.k.a. inductive bias). However, human imagination often falls short of modeling all of the intricate details of the true underlying generative process. In the large data regime there is an alternative strategy which we could call "learning to infer". Here, we create lots of data pairs {xn, yn} with {yn} the observed variables and {xn} the latent unobserved random variables. These can be generated from the generative model or are available directly in the dataset.
* Majority of this work was done while Zeynep Akata was at the University of Amsterdam.

33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

Figure 1: Examples of inferred 5K length trajectories for the Lorenz attractor with \Delta t = 0.01, trained on a 50K length trajectory. The mean squared errors from left to right are (Observations: 0.2462, GNN: 0.0613, E-Kalman Smoother: 0.0372, Hybrid: 0.0169).

Our task is now to learn a flexible mapping q(x|y) to infer the latent variables directly from the observations. This idea is known as "inverse modeling" in some communities. It is also known as "amortized" inference [32] or recognition networks in the world of variational autoencoders [18] and Helmholtz machines [11].

In this paper we consider inference as an iterative message passing scheme over the edges of the graphical model. We know that (approximate) inference in graphical models can be formulated as message passing, known as belief propagation, so this is a reasonable way to structure our computations. When we unroll these messages for N steps we have effectively created a recurrent neural network as our computation graph. We will enrich the traditional messages with a learnable component whose function is to correct the original messages when there is enough data available. In this way we create a hybrid message passing scheme with prior components from the graphical model and learned messages from data.
The learned messages may be interpreted as a kind of graph convolutional neural network [5, 15, 20].

Our Hybrid model neatly trades off the benefit of using inductive bias in the small data regime and the benefit of a much more flexible and learnable inference network when sufficient data is available. In this paper we restrict ourselves to a sequential model known as a hidden Markov process.

2 The Hidden Markov Process

In this section we briefly explain the Hidden Markov Process and how we intend to extend it. In a Hidden Markov Model (HMM), a set of unobserved variables x = {x_1, . . . , x_K} defines the state of a process at every time step k = 1, . . . , K. The set of observable variables from which we want to infer the process states is denoted by y = {y_1, . . . , y_K}. HMMs are used in diverse applications such as localization, tracking, weather forecasting and computational finance, among others (in fact, the Kalman filter was used to land the Eagle on the moon).

We can express p(x|y) as the probability distribution of the hidden states given the observations. Our goal is to find the states x that maximize this probability distribution. More formally:

\hat{x} = \arg\max_x p(x|y)    (1)

Under the Markov assumption, i) the transition model is described by the transition probability p(x_k|x_{k-1}), and ii) the measurement model is described by p(y_k|x_k). Both distributions are stationary for all k. The resulting graphical model can be expressed with the following equation:

p(x, y) = p(x_0) \prod_{k=1}^{K} p(x_k|x_{k-1}) p(y_k|x_k)    (2)

One of the best known approaches for inference problems in this graphical model is the Kalman Filter [17] and Smoother [31]. The Kalman Filter assumes both the transition and measurement distributions are linear and Gaussian.
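Under this linear-Gaussian assumption, filtering admits the classic closed-form predict/update recursion. The following is a minimal sketch with arbitrary placeholder 1-D parameters, not the implementation used in this paper:

```python
import numpy as np

def kalman_step(mu, P, y, F, H, Q, R):
    """One Kalman filter iteration: predict with the linear transition
    model, then correct the prediction with the new observation y."""
    # Predict: propagate mean and covariance through the dynamics.
    mu_pred = F @ mu
    P_pred = F @ P @ F.T + Q
    # Update: the Kalman gain weighs model prediction against measurement.
    S = H @ P_pred @ H.T + R                # innovation covariance
    K = P_pred @ H.T @ np.linalg.inv(S)
    mu_new = mu_pred + K @ (y - H @ mu_pred)
    P_new = (np.eye(len(mu)) - K @ H) @ P_pred
    return mu_new, P_new

# Toy 1-D constant-position model (placeholder parameters).
F = np.array([[1.0]]); H = np.array([[1.0]])
Q = np.array([[0.01]]); R = np.array([[0.25]])
mu, P = np.array([0.0]), np.array([[1.0]])
for y in [0.9, 1.1, 1.0]:
    mu, P = kalman_step(mu, P, np.array([y]), F, H, Q, R)
```

With the right F, H, Q, R this recursion (plus a smoothing pass) is exactly the optimal estimator referred to below.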
The prior knowledge we have about the process is encoded in linear transition and measurement processes, and the uncertainty of the predictions with respect to the real system is modeled by Gaussian noise:

x_k = F x_{k-1} + q_k    (3)
y_k = H x_k + r_k    (4)

Here q_k, r_k come from Gaussian distributions q_k \sim N(0, Q), r_k \sim N(0, R). F, H are the linear transition and measurement functions respectively. If the process from which we are inferring x is actually Gaussian and linear, a Kalman Filter + Smoother with the right parameters is able to infer the optimal state estimates.

The real world is usually non-linear and complex; assuming that a process is linear may be a strong limitation. Some alternatives like the Extended Kalman Filter [24] and the Unscented Kalman Filter [34] are used for non-linear estimation, but even when functions are non-linear, they are still constrained to our knowledge about the dynamics of the process, which may differ from real world behavior.

To model the complexities of the real world we intend to learn them from data through flexible models such as neural networks. In this work we present a hybrid inference algorithm that combines the knowledge from a generative model (e.g. physics equations) with a function that is automatically learned from data using a neural network. In our experiments we show that this hybrid method outperforms the graphical inference methods and the neural network methods in the low and high data regimes respectively. In other words, our method benefits from the inductive bias in the limit of small data and also from the high capacity of a neural network in the limit of large data. The model is shown to gracefully interpolate between these regimes.

3 Related Work

The proposed method has interesting relations with meta learning [2] since it learns more flexible messages on top of an existing algorithm.
It is also related to structured prediction energy networks [3], which are discriminative models that exploit the structure of the output. Structured inference in relational outputs has been effective in a variety of tasks like pose estimation [35], activity recognition [12] and image classification [28]. One of the closest works is Recurrent Inference Machines (RIM) [30], where a generative model is also embedded into a Recurrent Neural Network (RNN); however, in that work graphical models played no role. In the same line of learned recurrent inference, our optimization procedure shares similarities with Iterative Amortized Inference [25], although in our work we are refining the gradient using a hybrid setting while they are learning it.

Another related line of research is the convergence of graphical models with neural networks: [26] replaced the joint probabilities with trainable factors for time series data. Learning the messages in conditional random fields has been effective in segmentation tasks [7, 37]. Relatedly, [16] runs message passing algorithms on top of a latent representation learned by a deep neural network. More recently, [36] showed the efficacy of using Graph Neural Networks (GNNs) for inference on a variety of graphical models, and compared the performance with classical inference algorithms. This last work is in a similar vein as ours, but in our case, learned messages are used to correct the messages from graphical inference. In the experiments we will show that this hybrid approach really improves over running GNNs in isolation.

The Kalman Filter is a widely used algorithm for inference in Hidden Markov Processes. Some works have explored the direction of coupling it with machine learning techniques. A method to discriminatively learn the noise parameters of a Kalman Filter was introduced by [1].
In order to input more complex variables, [14] back-propagates through the Kalman Filter such that an encoder can be trained at its input. Similarly, [9] replaces the dynamics defined in the Kalman Filter with a neural network. In our hybrid model, instead of replacing the already considered dynamics, we simultaneously train a learnable function for the purpose of inference.

4 Model

We cast our inference model as a message passing scheme where the nodes of a probabilistic graphical model can send messages to each other to infer estimates of the states x. Our aim is to develop a hybrid scheme where messages derived from the generative graphical model are combined with GNN messages:

Graphical Model Messages (GM-messages): These messages are derived from the generative graphical model (e.g. equations of motion from a physics model).

Graph Neural Network Messages (GNN-messages): These messages are learned by a GNN which is trained to reduce the inference error on labelled data in combination with the GM-messages.

Figure 2: Graphical illustration of our Hybrid algorithm. The GM-module (blue box) sends messages to the GNN-module (red box) which refines the estimation of x.

In the following two subsections we introduce the two types of messages and the final hybrid inference scheme.

4.1 Graphical Model Messages

In order to define the GM-messages, we interpret inference as an iterative optimization process to estimate the maximum likelihood values of the states x.
In its more generic form, the recursive update for each consecutive estimate of x is given by:

x^{(i+1)} = x^{(i)} + \gamma \nabla_{x^{(i)}} \log p(x^{(i)}, y)    (5)

Factorizing equation 5 according to the hidden Markov Process from equation 2, we get three input messages for each inferred node x_k:

x_k^{(i+1)} = x_k^{(i)} + \gamma M_k^{(i)},  M_k^{(i)} = \mu_{x_{k-1} \to x_k}^{(i)} + \mu_{x_{k+1} \to x_k}^{(i)} + \mu_{y_k \to x_k}^{(i)}    (6)

\mu_{x_{k-1} \to x_k}^{(i)} = \frac{\partial}{\partial x_k^{(i)}} \log p(x_k^{(i)} | x_{k-1}^{(i)})    (7)

\mu_{x_{k+1} \to x_k}^{(i)} = \frac{\partial}{\partial x_k^{(i)}} \log p(x_{k+1}^{(i)} | x_k^{(i)})    (8)

\mu_{y_k \to x_k}^{(i)} = \frac{\partial}{\partial x_k^{(i)}} \log p(y_k | x_k^{(i)})    (9)

These messages can be obtained by computing the three derivatives from equations 7, 8, 9. It is often assumed that the transition and measurement distributions p(x_k|x_{k-1}), p(y_k|x_k) are linear and Gaussian (e.g. the Kalman Filter model). Next, we provide the expressions for the GM-messages when assuming the linear and Gaussian functions from equations 3, 4:

\mu_{x_{k-1} \to x_k} = -Q^{-1}(x_k - F x_{k-1})    (10)
\mu_{x_{k+1} \to x_k} = F^T Q^{-1}(x_{k+1} - F x_k)    (11)
\mu_{y_k \to x_k} = H^T R^{-1}(y_k - H x_k)    (12)

4.2 Adding GNN-messages

We call v the collection of nodes of the graphical model, v = x \cup y. We also define an equivalent graph where the GNN operates by propagating the GNN messages. We build the following mappings from the nodes of the graphical model to the nodes of the GNN: h_x = {\phi(x) : x \in x}, h_y = {\phi(y) : y \in y}. Analogously, the union of both collections is h_v = h_x \cup h_y. Therefore, each node of the graphical model has a corresponding node h in the GNN. The edges of both graphs are also equivalent. The values h_x^{(0)} that correspond to unobserved variables x are randomly initialized. Instead, the values h_y^{(0)} are obtained by forwarding y_k through a linear layer.

Next we present the equations of the learned messages, which consist of a GNN message passing operation.
Similarly to [23, 19], a GRU [8] is added to the message passing operation to make it recursive:

m_{k,n}^{(i)} = z_{k,n} f_e(h_{x_k}^{(i)}, h_{v_n}^{(i)}, \mu_{v_n \to x_k})    (message from GNN nodes to edge factor)    (13)

U_k^{(i)} = \sum_{v_n \neq x_k} m_{k,n}^{(i)}    (message from edge factors to GNN node)    (14)

h_{x_k}^{(i+1)} = \mathrm{GRU}(U_k^{(i)}, h_{x_k}^{(i)})    (RNN update)    (15)

\epsilon_k^{(i+1)} = f_{dec}(h_{x_k}^{(i+1)})    (computation of correction factor)    (16)

Each GNN message is computed by the function f_e(\cdot), which receives as input two hidden states from the last recurrent iteration and their corresponding GM-message; this function is different for each type of edge (e.g. transition or measurement for the HMM). z_{k,n} takes value 1 if there is an edge between v_n and x_k, otherwise its value is 0. The sum of messages U_k^{(i)} is provided as input to the GRU function that updates each hidden state h_{x_k}^{(i)}. The GRU is composed of a single GRU cell preceded by a linear layer at its input. Finally, a correction signal \epsilon_k^{(i+1)} is decoded from each hidden state h_{x_k}^{(i+1)} and added to the recursive operation 6, resulting in the final equation:

x_k^{(i+1)} = x_k^{(i)} + \gamma (M_k^{(i)} + \epsilon_k^{(i+1)})    (17)

In summary, equation 17 defines our hybrid model in a simple recursive form where x_k is updated through two contributions: one that relies on the probabilistic graphical model messages M_k^{(i)}, and one, \epsilon_k^{(i)}, that is automatically learned. We note that it is important that the GNN messages model the "residual error" of the GM inference process, which is often simpler than modeling the full signal. A visual representation of the algorithm is shown in Figure 2.

In the experimental section of this work we apply our model to the Hidden Markov Process; however, the above mentioned GNN-messages are not constrained to this particular graphical structure.
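As a concrete illustration, the graphical-model part of this recursion (equations 6 and 10-12, without the learned correction) can be sketched as follows. This is a toy sketch with placeholder 1-D parameters; the p(x_0) prior term is omitted for brevity:

```python
import numpy as np

def gm_messages_step(x, y, F, H, Qinv, Rinv, gamma):
    """One iteration of the GM-message updates: each state x_k receives
    gradient messages from its neighbours (eqs. 10-12) and takes a step
    of size gamma in their summed direction (eq. 6)."""
    K = len(x)
    x_new = [xk.copy() for xk in x]
    for k in range(K):
        M = np.zeros_like(x[k])
        if k > 0:                         # message from x_{k-1} -> x_k
            M += -Qinv @ (x[k] - F @ x[k - 1])
        if k < K - 1:                     # message from x_{k+1} -> x_k
            M += F.T @ Qinv @ (x[k + 1] - F @ x[k])
        M += H.T @ Rinv @ (y[k] - H @ x[k])   # measurement message
        x_new[k] = x[k] + gamma * M
    return x_new

# Toy 1-D random-walk model (placeholder parameters).
F = H = np.eye(1)
Qinv = np.eye(1) / 0.1
Rinv = np.eye(1) / 0.5
y = [np.array([v]) for v in (0.0, 1.0, 2.0)]
x = [yk.copy() for yk in y]               # initialize at the observations
for _ in range(50):
    x = gm_messages_step(x, y, F, H, Qinv, Rinv, gamma=0.005)
```

Iterating this update performs gradient ascent on the log-joint of equation 2; the hybrid model simply adds the decoded correction \epsilon_k^{(i+1)} inside the same step.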
The GM-messages can also be obtained for other arbitrary graph structures by applying the recursive inference equation 5 to their respective graphical models.

4.3 Training procedure

In order to provide early feedback, the loss function is computed at every iteration with a weighted sum that emphasizes later iterations, w_i = i/N. More formally:

Loss(\Theta) = \sum_{i=1}^{N} w_i L(gt, \psi(x^{(i)}))    (18)

where the function \psi(\cdot) extracts the part of the hidden state x contained in the ground truth gt. In our experiments we use the mean square error for L(\cdot). The training procedure consists of three main steps. First, we initialize x_k^{(0)} at the value that maximizes p(y_k|x_k). For example, in a trajectory estimation problem we set the position values of x_k to the observed positions y_k. Second, we tune the hyper-parameters of the graphical model as would be done for a Kalman Filter; these are usually the variances of the Gaussian distributions. Finally, we train the model using the above mentioned loss (equation 18).

5 Experiments

In this section we compare our Hybrid model with the Kalman Smoother and a recurrent GNN. We show that our Hybrid model can leverage the benefits of both methods in different data regimes. Next we define the models used in the experiments²:

Kalman Smoother: The Kalman Smoother is the widely known Kalman Filter algorithm [17] + the RTS smoothing step [31]. In experiments where the transition function is non-linear we use the Extended Kalman Filter + smoothing step, which we will call "E-Kalman Smoother".

GM-messages: As a special case of our hybrid model we propose to remove the learned signal \epsilon_k^{(i)} and base our predictions only on the graphical model messages from eq. 6.

GNN-messages: The GNN model is another special case of our model where all the GM-messages are removed and only GNN messages are propagated. Instead of decoding a refinement for the current estimate x_k^{(i)}, we directly estimate x_k^{(i)} = H^T y_k + f_{dec}(h_{x_k}^{(i)}). The resulting algorithm is equivalent to a Gated Graph Neural Network [23].

Hybrid model: This is our full model explained in section 4.2.

We set \gamma = 0.005 and use the Adam optimizer with a learning rate of 10^{-3}. The number of inference iterations used in the Hybrid model, GNN-messages and GM-messages is N = 50. f_e and f_dec are 2-layer MLPs with Leaky ReLU and ReLU activations respectively. The number of features in the hidden layers of the GRU, f_e and f_dec is nf = 48. In trajectory estimation experiments, y_k values may take any value from the real numbers R. Shifting a trajectory to a non-previously seen position may hurt the generalization performance of the neural network. To make the problem translation invariant we modify y_k before mapping it to h_{y_k}: we use the differences between the observed current position and the previous one and the next one.

²Available at: https://github.com/vgsatorras/hybrid-inference

5.1 Linear dynamics

The aim of this experiment is to infer the position of every node in trajectories generated by linear and Gaussian equations. The advantage of using a synthetic environment is that we know in advance the original equations the motion pattern was generated from, and by providing the right linear and Gaussian equations to a Kalman Smoother we can obtain the optimal inferred estimate as a lower bound on the test loss.

Among other tasks, Kalman Filters are used to refine the noisy measurements of GPS systems. A physics model of the dynamics can be provided to the graphical model which, combined with the noisy measurements, gives a more accurate estimation of the position.
The real world is usually more complex than the equations we may provide to our graphical model, leading to a gap between the assumed dynamics and the real world dynamics. Our hybrid model is able to fill this gap without the need to learn everything from scratch.

To show that, we generate synthetic trajectories T = {x, y}. Each state x_k \in R^6 is a 6-dimensional vector that encodes position, velocity and acceleration (p, v, a) for two dimensions. Each y_k \in R^2 is a noisy measurement of the position, also for two dimensions. The transition dynamic is a non-uniformly accelerated motion that also considers drag (air resistance):

\frac{\partial p}{\partial t} = v,  \frac{\partial v}{\partial t} = a - cv,  \frac{\partial a}{\partial t} = \tau_v    (19)

where -cv represents the air resistance [13], with c being a constant that depends on the properties of the fluid and the object dimensions. Finally, the variable \tau_v is used to non-uniformly accelerate the object.

To generate the dataset, we sample from the Markov process of equation 2 where the transition probability distribution p(x_{k+1}|x_k) and the measurement probability distribution p(y_k|x_k) follow equations (3, 4). The values F, Q, H, R for these distributions are described in the Appendix; in particular, F is analytically obtained from the above mentioned differential equations 19. We sample two different motion trajectories of 50 to 100K time steps each, one for validation and the other for training. An additional 10K time steps trajectory is sampled for testing. The sampling time step is \Delta t = 1.

Figure 3: MSE comparison with respect to the number of training samples for the linear dynamics dataset.

Alternatively, the graphical model of the algorithm is limited to a uniform motion pattern p = p_0 + vt. Its equivalent differential equations form would be \frac{\partial p}{\partial t} = v. Notice that the air friction is not considered anymore and velocity and acceleration are assumed to be uniform.
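The data-generating dynamics of equation 19 can be sketched with a simple Euler integrator. This is a toy sketch: the constants c and the \tau_v drive, modeled here as random noise, as well as the measurement noise level, are placeholder choices; the paper's exact values are given in its Appendix:

```python
import numpy as np

def simulate(steps, dt=1.0, c=0.06, tau=0.1, seed=0):
    """Euler integration of the drag dynamics of eq. 19 in one dimension:
    dp/dt = v, dv/dt = a - c*v, da/dt = tau_v, with Gaussian noise
    added to the position measurements."""
    rng = np.random.default_rng(seed)
    p = v = a = 0.0
    xs, ys = [], []
    for _ in range(steps):
        tau_v = tau * rng.standard_normal()   # non-uniform acceleration drive
        p += v * dt
        v += (a - c * v) * dt
        a += tau_v * dt
        xs.append((p, v, a))
        ys.append(p + 0.5 * rng.standard_normal())  # noisy position observation
    return np.array(xs), np.array(ys)

states, obs = simulate(1000)
```

The inference task is to recover `states` from `obs`, while the graphical model is only told about the simpler uniform-motion prior described above.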
Again the parameters for the matrices F, Q, H, R when considering a uniform motion pattern are analytically obtained and described in the Appendix.

Results. The Mean Square Error with respect to the number of training samples is shown for the different algorithms in Figure 3. The plot shows the average and the standard deviation over 7 runs; the sampled test trajectory remains the same over all runs, but this is not the case for the training and validation sampled trajectories. Note that the MSE of the Kalman Smoother and the GM-messages overlap in the plot since both errors were exactly the same.

Our model outperforms both the GNN and the Kalman Smoother run in isolation in all data regimes, and it has a significant edge over the Kalman Smoother when the number of samples is larger than 1K. This shows that our model is able to combine the advantages of prior knowledge and deep learning in a single framework. These results show that our hybrid model benefits from the inductive bias of the graphical model equations when data is scarce, and simultaneously benefits from the flexibility of the GNN when data is abundant.

A clear trade-off can be observed between the Kalman Smoother and the GNN. The Kalman Smoother clearly performs better in low data regimes, while the GNN outperforms it for larger amounts of data (>10K). The hybrid model is able to benefit from the strengths of both.

5.2 Lorenz Attractor

The Lorenz equations describe a non-linear chaotic system used for atmospheric convection. Learning the dynamics of this chaotic system in a supervised way is expected to be more challenging than for linear dynamics, making it an interesting evaluation of our Hybrid model.
A Lorenz system is modelled by three differential equations that define the convection rate, the horizontal temperature variation and the vertical temperature variation of a fluid:

\frac{\partial z_1}{\partial t} = 10(z_2 - z_1),  \frac{\partial z_2}{\partial t} = z_1(28 - z_3) - z_2,  \frac{\partial z_3}{\partial t} = z_1 z_2 - \frac{8}{3} z_3    (20)

To generate a trajectory we run the Lorenz equations 20 with dt = 10^{-5}, from which we sample with a time step of \Delta t = 0.05, resulting in a single trajectory of 104K time steps. Each point is then perturbed with Gaussian noise of standard deviation \sigma = 0.5. From this trajectory, 4K time steps are separated for testing; the remaining trajectory of 100K time steps is equally split between training and validation partitions.

Assuming x \in R^3 is a 3-dimensional vector x = [z_1, z_2, z_3]^T, we can write down the dynamics matrix of the system as A|_x from the Lorenz differential eq. 20, and obtain the transition function F|_{x_k} [22] using the Taylor expansion:

\dot{x} = A|_x x = \begin{bmatrix} -10 & 10 & 0 \\ 28 - z_3 & -1 & 0 \\ z_2 & 0 & -\frac{8}{3} \end{bmatrix} \begin{bmatrix} z_1 \\ z_2 \\ z_3 \end{bmatrix},  F|_{x_k} = I + \sum_{j=1}^{J} \frac{(A|_{x_k} \Delta t)^j}{j!}    (21)

where I is the identity matrix and J is the number of terms from the Taylor expansion. We run simulations for J=1, J=2 and J=5; for larger J the improvement was minimal. For the measurement model we use the identity matrix H = I. For the noise distributions we use diagonal matrices Q = \lambda^2 \Delta t I and R = 0.5^2 I. The only hyper-parameter to tune in the graphical model is \lambda.

Since the dynamics are non-linear, the matrix F|_{x_k} depends on the values x_k. The presence of these variables inside the matrix introduces a simple non-linearity that makes the function much harder to learn.

Figure 4: MSE with respect to the number of training samples on the Lorenz Attractor.

Results. The results in Figure 4 show that the GNN struggles to achieve low errors for this chaotic system, i.e.
it does not converge together with the hybrid model even when the training dataset contains up to 10^5 samples and the hybrid loss is already 0.01 ∼ 0.02. We attribute this difficulty to the fact that the matrix F|_{x_k} is different at every state x_k, making it harder to approximate.

This behavior is different from the previous experiment (linear dynamics), where both the Hybrid model and the GNN converged to the optimal solution in high data regimes. In this experiment, even when the GNN and the E-Kalman Smoother perform poorly, the Hybrid model gets closer to the optimal solution, outperforming both of them in isolation. This shows that the Hybrid model benefits from the labeled data even in situations where its fully-supervised variant or the E-Kalman Smoother are unable to properly model the process. One reason for this could be that the residual dynamics (i.e. the error of the E-Kalman Smoother) are much more linear than the original dynamics and hence easier to model by the GNN.

As can be seen in Figure 4, depending on the amount of prior knowledge used in our hybrid model we need more or fewer samples to achieve a particular accuracy. Below, we show in Table 1 the approximate number of training samples required to achieve MSEs of 0.1 and 0.05 depending on the amount of knowledge we provide (i.e. the number of terms J of the Taylor expansion). The hybrid method requires ∼10 times fewer samples than the fully-learned method for MSE = 0.1 and ∼20 times fewer samples for MSE = 0.05.

             GNN (J=0)   Hybrid (J=1)   Hybrid (J=2 & J=5)
MSE = 0.1    ∼5,000      ∼400           ∼500
MSE = 0.05   ∼90,000     ∼4,000         ∼5,000

Table 1: Number of samples required to achieve a particular MSE depending on the amount of prior knowledge (i.e. J). These numbers have been extracted from Figure 4.

Qualitative results of trajectories estimated by the different algorithms on the Lorenz attractor are depicted in Figure 1.
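As an aside, the linearized transition function of equation 21 is straightforward to compute; a minimal sketch of the truncated Taylor (matrix-exponential) expansion follows, with an arbitrary example state:

```python
import numpy as np
from math import factorial

def lorenz_A(x):
    """State-dependent dynamics matrix A|x of eq. 21."""
    z1, z2, z3 = x
    return np.array([[-10.0,      10.0,  0.0],
                     [28.0 - z3,  -1.0,  0.0],
                     [z2,          0.0, -8.0 / 3.0]])

def transition_F(x, dt, J):
    """F|x = I + sum_{j=1..J} (A|x * dt)^j / j!  (truncated series)."""
    A_dt = lorenz_A(x) * dt
    F = np.eye(3)
    term = np.eye(3)
    for j in range(1, J + 1):
        term = term @ A_dt            # accumulates (A|x * dt)^j
        F += term / factorial(j)
    return F

x = np.array([1.0, 1.0, 25.0])        # arbitrary example state
F1 = transition_F(x, dt=0.05, J=1)    # first order: I + A|x * dt
F5 = transition_F(x, dt=0.05, J=5)
```

The gap between F1 and F5 is exactly the extra prior knowledge the J column of Table 1 refers to.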
The plots correspond to a 5K length test trajectory (with \Delta t = 0.01). All trainable algorithms have been trained on 5K length trajectories.

5.3 Real World Dynamics: Michigan NCLT dataset

To demonstrate the generalizability of our Hybrid model to real world datasets, we use the Michigan NCLT dataset [6], which is collected by a Segway robot moving around the University of Michigan's North Campus. It comprises different trajectories where the GPS measurements and the ground truth location of the robot are provided. Given these noisy GPS observations, our goal is to infer a more accurate position of the Segway at a given time.

In our experiments we arbitrarily use the session with date 2012-01-22, which consists of a single trajectory of 6.1 km on a cloudy day. Sampling at 1 Hz results in 4,629 time steps, and after removing the parts with an unstable GPS signal, 4,344 time steps remain. Finally, we split the trajectory into three sections: 1,502 time steps for training, 1,469 for validation and 1,373 for testing. The GPS measurements are assumed to be the noisy measurements denoted by y.

For the transition and measurement graphical model distributions we assume the same uniform motion model used in section 5.1, specifically the dynamics of a uniform motion pattern. The only parameters to learn in the graphical model are the variances of the measurement and transition distributions. The detailed equations are presented in the Appendix.

ALGORITHM                   MSE
OBSERVATIONS (BASELINE)     3.4974
KALMAN SMOOTHER             3.0099
GM-MESSAGES                 3.0048
GNN-MESSAGES                1.7929
HYBRID MODEL                1.4109

Table 2: MSE for different methods on the Michigan NCLT dataset.

Results. Our results show that our Hybrid model (1.4109 MSE) outperforms the GNN (1.7929 MSE), the Kalman Smoother (3.0099 MSE) and the GM-messages (3.0048 MSE).
One of the advantages of the GNN and the Hybrid methods on real world datasets is that both can model the correlations through time of the noise distributions, while the GM-messages and the Kalman Smoother assume the noise to be uncorrelated through time, as it is defined in the graphical model. In summary, this experiment shows that our hybrid model can generalize with good performance to a real world dataset.

6 Discussion

In this work, we explored the combination of recent advances in neural networks (e.g. graph neural networks) with more traditional methods of graphical inference in hidden Markov models for time series. The result is a hybrid algorithm that benefits from the inductive bias of graphical models and from the high flexibility of neural networks. We demonstrated these benefits on three different tasks for trajectory estimation: a linear dynamics dataset, a non-linear chaotic system (the Lorenz attractor) and a real world positioning system. In all three experiments, the Hybrid method learns to efficiently combine graphical inference with learned inference, outperforming both when run in isolation.

Possible future directions include applying our idea to other generative models. The equations that describe our hybrid model are defined on edges and nodes; therefore, by modifying the input graph, i.e. by modifying the edges and nodes of the input graph, we can run our algorithm on arbitrary graph structures. Other future directions include exploring hybrid methods for performing probabilistic inference in other graphical models (e.g. with discrete variables), as well as learning the graphical model itself. In this work we used cross-validation to make sure we did not overfit the GNN part of the model to the data at hand, optimally balancing prior knowledge and data-driven inference. In the future we intend to explore a more principled Bayesian approach to this.
Finally, hybrid models like the one presented in this paper can help improve the interpretability of model predictions due to their graphical model backbone.

References

[1] P. Abbeel, A. Coates, M. Montemerlo, A. Y. Ng, and S. Thrun. Discriminative training of Kalman filters. In Robotics: Science and Systems, volume 2, page 1, 2005.

[2] M. Andrychowicz, M. Denil, S. Gomez, M. W. Hoffman, D. Pfau, T. Schaul, B. Shillingford, and N. De Freitas. Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems, pages 3981–3989, 2016.

[3] D. Belanger, B. Yang, and A. McCallum. End-to-end learning for structured prediction energy networks. In Proceedings of the 34th International Conference on Machine Learning, volume 70, pages 429–439. JMLR.org, 2017.

[4] C. M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.

[5] J. Bruna, W. Zaremba, A. Szlam, and Y. LeCun. Spectral networks and locally connected networks on graphs. arXiv preprint arXiv:1312.6203, 2013.

[6] N. Carlevaris-Bianco, A. K. Ushani, and R. M. Eustice. University of Michigan North Campus long-term vision and lidar dataset. The International Journal of Robotics Research, 35(9):1023–1035, 2016.

[7] L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, and A. L. Yuille. Semantic image segmentation with deep convolutional nets and fully connected CRFs. arXiv preprint arXiv:1412.7062, 2014.

[8] J. Chung, C. Gulcehre, K. Cho, and Y. Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014.

[9] H. Coskun, F. Achilles, R. DiPietro, N. Navab, and F. Tombari. Long short-term memory Kalman filters: Recurrent neural estimators for pose regularization. arXiv preprint arXiv:1708.01885, 2017.

[10] C. Crick and A. Pfeffer. Loopy belief propagation as a basis for communication in sensor networks.
In Proceedings of the Nineteenth Conference on Uncertainty in Artificial Intelligence, pages 159–166. Morgan Kaufmann Publishers Inc., 2002.

[11] P. Dayan, G. E. Hinton, R. M. Neal, and R. S. Zemel. The Helmholtz machine. Neural Computation, 7(5):889–904, 1995.

[12] Z. Deng, A. Vahdat, H. Hu, and G. Mori. Structure inference machines: Recurrent neural networks for analyzing relations in group activity recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4772–4781, 2016.

[13] G. Falkovich. Fluid Mechanics: A Short Course for Physicists. Cambridge University Press, 2011.

[14] T. Haarnoja, A. Ajay, S. Levine, and P. Abbeel. Backprop KF: Learning discriminative deterministic state estimators. In Advances in Neural Information Processing Systems, pages 4376–4384, 2016.

[15] M. Henaff, J. Bruna, and Y. LeCun. Deep convolutional networks on graph-structured data. arXiv preprint arXiv:1506.05163, 2015.

[16] M. Johnson, D. K. Duvenaud, A. Wiltschko, R. P. Adams, and S. R. Datta. Composing graphical models with neural networks for structured representations and fast inference. In Advances in Neural Information Processing Systems, pages 2946–2954, 2016.

[17] R. E. Kalman. A new approach to linear filtering and prediction problems. Journal of Basic Engineering, 82(1):35–45, 1960.

[18] D. P. Kingma and M. Welling. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.

[19] T. Kipf, E. Fetaya, K.-C. Wang, M. Welling, and R. Zemel. Neural relational inference for interacting systems. arXiv preprint arXiv:1802.04687, 2018.

[20] T. N. Kipf and M. Welling. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.

[21] D. Koller, N. Friedman, and F. Bach. Probabilistic Graphical Models: Principles and Techniques. MIT Press, 2009.

[22] R. Labbe.
Kalman and Bayesian Filters in Python, 2014.

[23] Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. Gated graph sequence neural networks. arXiv preprint arXiv:1511.05493, 2015.

[24] L. Ljung. Asymptotic behavior of the extended Kalman filter as a parameter estimator for linear systems. IEEE Transactions on Automatic Control, 24(1):36–50, 1979.

[25] J. Marino, Y. Yue, and S. Mandt. Iterative amortized inference. arXiv preprint arXiv:1807.09356, 2018.

[26] P. Mirowski and Y. LeCun. Dynamic factor graphs for time series modeling. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases, pages 128–143. Springer, 2009.

[27] K. P. Murphy. Machine Learning: A Probabilistic Perspective. MIT Press, 2012.

[28] N. Nauata, H. Hu, G.-T. Zhou, Z. Deng, Z. Liao, and G. Mori. Structured label inference for visual understanding. arXiv preprint arXiv:1802.06459, 2018.

[29] R. M. Neal et al. MCMC using Hamiltonian dynamics. Handbook of Markov Chain Monte Carlo, 2(11):2, 2011.

[30] P. Putzky and M. Welling. Recurrent inference machines for solving inverse problems. arXiv preprint arXiv:1706.04008, 2017.

[31] H. E. Rauch, C. Striebel, and F. Tung. Maximum likelihood estimates of linear dynamic systems. AIAA Journal, 3(8):1445–1450, 1965.

[32] D. J. Rezende and S. Mohamed. Variational inference with normalizing flows. arXiv preprint arXiv:1505.05770, 2015.

[33] T. Salimans, D. Kingma, and M. Welling. Markov chain Monte Carlo and variational inference: Bridging the gap. In International Conference on Machine Learning, pages 1218–1226, 2015.

[34] E. A. Wan and R. Van Der Merwe. The unscented Kalman filter for nonlinear estimation. In Adaptive Systems for Signal Processing, Communications, and Control Symposium (AS-SPCC), pages 153–158. IEEE, 2000.

[35] S.-E. Wei, V. Ramakrishna, T. Kanade, and Y. Sheikh.
Convolutional pose machines. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4724–4732, 2016.

[36] K. Yoon, R. Liao, Y. Xiong, L. Zhang, E. Fetaya, R. Urtasun, R. Zemel, and X. Pitkow. Inference in probabilistic graphical models by graph neural networks. arXiv preprint arXiv:1803.07710, 2018.

[37] S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang, and P. H. Torr. Conditional random fields as recurrent neural networks. In Proceedings of the IEEE International Conference on Computer Vision, pages 1529–1537, 2015.