{"title": "Just-In-Time Learning for Fast and Flexible Inference", "book": "Advances in Neural Information Processing Systems", "page_first": 154, "page_last": 162, "abstract": "Much of research in machine learning has centered around the search for inference algorithms that are both general-purpose and efficient. The problem is extremely challenging and general inference remains computationally expensive. We seek to address this problem by observing that in most specific applications of a model, we typically only need to perform a small subset of all possible inference computations. Motivated by this, we introduce just-in-time learning, a framework for fast and flexible inference that learns to speed up inference at run-time. Through a series of experiments, we show how this framework can allow us to combine the flexibility of sampling with the efficiency of deterministic message-passing.", "full_text": "Just-In-Time Learning for Fast and Flexible Inference\n\nS. M. Ali Eslami, Daniel Tarlow, Pushmeet Kohli and John Winn\n\n{alie,dtarlow,pkohli,jwinn}@microsoft.com\n\nMicrosoft Research\n\nAbstract\n\nMuch of research in machine learning has centered around the search for inference\nalgorithms that are both general-purpose and ef\ufb01cient. The problem is extremely\nchallenging and general inference remains computationally expensive. We seek to\naddress this problem by observing that in most speci\ufb01c applications of a model,\nwe typically only need to perform a small subset of all possible inference com-\nputations. Motivated by this, we introduce just-in-time learning, a framework for\nfast and \ufb02exible inference that learns to speed up inference at run-time. 
Through\na series of experiments, we show how this framework can allow us to combine the\n\ufb02exibility of sampling with the ef\ufb01ciency of deterministic message-passing.\n\n1\n\nIntroduction\n\nWe would like to live in a world where we can de\ufb01ne a probabilistic model, press a button, and\nget accurate inference results within a matter of seconds or minutes. Probabilistic programming\nlanguages allow for the rapid de\ufb01nition of rich probabilistic models to this end, but they also raise a\ncrucial question: what algorithms can we use to ef\ufb01ciently perform inference for the largest possible\nset of programs in the language? Much of recent research in machine learning has centered around\nthe search for inference algorithms that are both \ufb02exible and ef\ufb01cient.\nThe general inference problem is extremely challenging and remains computationally expensive.\nSampling based approaches (e.g. [5, 19]) can require many evaluations of the probabilistic program\nto obtain accurate inference results. Message-passing based approaches (e.g. [12]) are typically\nfaster, but require the program to be expressed in terms of functions for which ef\ufb01cient message-\npassing operators have been implemented. However, implementing a message-passing operator for\na new function either requires technical expertise, or is computationally expensive, or both.\nIn this paper we propose a solution to this problem that is automatic (it doesn\u2019t require the user\nto build message passing operators) and ef\ufb01cient (it learns from past experience to make future\ncomputations faster). The approach is motivated by the observation that general algorithms are\nsolving problems that are harder than they need to be: in most speci\ufb01c inference problems, we only\never need to perform a small subset of all possible message-passing computations. 
For example, in Expectation Propagation (EP) the range of input messages to a logistic factor, for which it needs to compute output messages, is highly problem specific (see Fig. 1a). This observation raises the central question of our work: can we automatically speed up the computations required for general message-passing, at run-time, by learning about the statistics of the specific problems at hand?
Our proposed framework, which we call just-in-time learning (JIT learning), initially uses highly general algorithms for inference. It does so by computing messages in a message-passing algorithm using Monte Carlo sampling, freeing us from having to implement hand-crafted message update operators. However, it also gradually learns to increase the speed of these computations by regressing from input to output messages (in a similar way to [7]) at run-time. JIT learning enables us to combine the flexibility of sampling (by allowing arbitrary factors) and the speed of hand-crafted message-passing operators (by using regressors), without having to do any pre-training. This constitutes our main contribution and we describe the details of our approach in Sec. 3.

[Figure 1: (a) problem-specific variation; (b) random forest uncertainty; (c) the yield model factor graph, with GP, Eval, Yield and Noise factors over the variables x_i, a_i, t_i, t^opt_i, y^max_i, y^avg_i and y_i.]

Figure 1: (a) Parameters of Gaussian messages input to a logistic factor in logistic regression vary significantly in four random UCI datasets. (b) Figure for Sec. 4: A regression forest performs 1D regression (1,000 trees, 2 feature samples per node, maximum depth 4, regressor polynomial degree 2). The red shaded area indicates one standard deviation of the predictions made by the different trees in the forest, indicating its uncertainty. (c) Figure for Sec. 
6: The yield factor relates temperatures and yields recorded at farms to the optimal temperatures of their planted grain. JIT learning enables us to incorporate arbitrary factors with ease, whilst maintaining inference speed.

Our implementation relies heavily on the use of regressors that are aware of their own uncertainty. Their awareness about the limits of their knowledge allows them to decide when to trust their predictions and when to fall back to computationally intensive Monte Carlo sampling (similar to [8] and [9]). We show that random regression forests [4] form a natural and efficient basis for this class of 'uncertainty aware' regressors and we describe how they can be modified for this purpose in Sec. 4. To the best of our knowledge this is the first application of regression forests to the self-aware learning setting and it constitutes our second contribution.
To demonstrate the efficacy of the JIT framework, we employ it for inference in a variety of graphical models. Experimental results in Sec. 6 show that for general graphical models, our approach leads to significant improvements in inference speed (often several orders of magnitude) over importance sampling whilst maintaining overall accuracy, even boosting performance for models where hand-designed EP message-passing operators are available. Although we demonstrate JIT learning in the context of expectation propagation, the underlying ideas are general and the framework can be used for arbitrary inference problems.

2 Background

A wide class of probabilistic models can be represented using the framework of factor graphs. In this context a factor graph represents the factorization of the joint distribution over a set of random variables x = {x_1, ..., x_V} via non-negative factors \psi_1, ..., \psi_F, given by p(x) = \prod_f \psi_f(x_{ne(\psi_f)}) / Z, where x_{ne(\psi_f)} is the set of variables that factor \psi_f is defined over. We will focus on directed factors of the form \psi(x_out | x_in) which directly specify the conditional density over the output variables x_out as a function of the inputs x_in, although our approach can be extended to factors of arbitrary form.
Belief propagation (or sum-product) is a message-passing algorithm for performing inference in factor graphs with discrete and real-valued variables, and it includes sub-routines that compute variable-to-factor and factor-to-variable messages. The bottleneck is mainly in computing the latter kind, as they often involve intractable integrals. The message from factor \psi to variable i is:

    m_{\psi \to i}(x_i) = \int_{x_{-i}} \psi(x_out | x_in) \prod_{k \in ne(\psi) \setminus i} m_{k \to \psi}(x_k),    (1)

where x_{-i} denotes all random variables in x_{ne(\psi)} except i. To further complicate matters, the messages are often not even representable in a compact form. Expectation Propagation [11] extends the applicability of message-passing algorithms by projecting messages back to a pre-determined, tractable family of distributions:

    m_{\psi \to i}(x_i) = proj[ \int_{x_{-i}} \psi(x_out | x_in) \prod_{k \in ne(\psi)} m_{k \to \psi}(x_k) ] / m_{i \to \psi}(x_i).    (2)

[Figure 1a/1b plots: mean vs. log precision of input messages for the banknote_authentication, blood_transfusion, ionosphere and fertility_diagnosis datasets; training datapoints and forest predictions.]

The proj[.] operator ensures that the message is a distribution of the correct type and only has an effect if its argument is outside the approximating family used for the target message.
The integral in the numerator of Eq. 
2 can be computed using Monte Carlo methods [2, 7], e.g. by using the generally applicable technique of importance sampling. After multiplying and dividing by a proposal distribution q(x_in) we get:

    m_{\psi \to i}(x_i) \equiv proj[ \int_{x_{-i}} v(x_in, x_out) \cdot w(x_in, x_out) ] / m_{i \to \psi}(x_i),    (3)

where v(x_in, x_out) = q(x_in) \psi(x_out | x_in) and w(x_in, x_out) = \prod_{k \in ne(\psi)} m_{k \to \psi}(x_k) / q(x_in). Therefore

    m_{\psi \to i}(x_i) \simeq proj[ \sum_s w(x^s_in, x^s_out) \delta(x_i) / \sum_s w(x^s_in, x^s_out) ] / m_{i \to \psi}(x_i),    (4)

where x^s_in and x^s_out are samples from v(x_in, x_out). To sample from v, we first draw values x^s_in from q, then pass them through the forward-sampling procedure defined by \psi to get a value for x^s_out.
Crucially, note that we require no knowledge of \psi other than the ability to sample from \psi(x_out | x_in). This allows the model designer to incorporate arbitrary factors simply by providing an implementation of this forward sampler, which could be anything from a single line of deterministic code to a large stochastic image renderer. However, drawing a single sample from \psi can itself be a time-consuming operation, and the complexity of \psi and the arity of x_in can both have a dramatic effect on the number of samples required to compute messages accurately.

3 Just-in-time learning of message mappings

Monte Carlo methods (as defined above) are computationally expensive and can lead to slow inference. In this paper, we adopt an approach in which we learn a direct mapping, parameterized by \theta, from variable-to-factor messages {m_{k \to \psi}}_{k \in ne(\psi)} to a factor-to-variable message m_{\psi \to i}:

    m_{\psi \to i}(x_i) \equiv f({m_{k \to \psi}}_{k \in ne(\psi)} | \theta).    (5)

Using this direct mapping function f, factor-to-variable messages can be computed in a fraction of the time required to perform full Monte Carlo estimation. Heess et al. [7] recently used neural networks to learn this mapping offline for a broad range of input message combinations.
Motivated by the observation that the distribution of input messages that a factor sees is often problem-specific (Fig. 1a), we consider learning the direct mapping just-in-time in the context of a specific model. For this we employ 'uncertainty aware' regressors. Along with each prediction m, the regressor produces a scalar measure u of its uncertainty about that prediction:

    u_{\psi \to i} \equiv u({m_{k \to \psi}}_{k \in ne(\psi)} | \theta).    (6)

We adopt a framework similar to that of uncertainty sampling [8] (also [9]) and use these uncertainties at run-time to choose between the regressor's estimate and slower 'oracle' computations:

    m_{\psi \to i}(x_i) = { m_{\psi \to i}(x_i)             if u_{\psi \to i} < u_max,
                            m^{oracle}_{\psi \to i}(x_i)    otherwise,    (7)

where u_max is the maximum tolerated uncertainty for a prediction. In this paper we consider importance sampling or hand-implemented Infer.NET operators as oracles; however, other methods such as MCMC-based samplers could be used. 
The regressor is updated after every oracle consultation in order to incorporate the newly acquired information.
An appropriate value for u_max can be found by collecting a small number of Monte Carlo messages for the target model offline: the uncertainty aware regressor is trained on some portion of the collected messages, and evaluated on the held out portion, producing predictions m_{\psi \to i} and confidences u_{\psi \to i} for every held out message. We then set u_max such that no held out prediction has an error above a user-specified, problem-specific maximum tolerated value D_max.
A natural choice for this error measure is the mean squared error of the parameters of the messages (e.g. natural parameters for the exponential family); however, this is sensitive to the particular parameterization chosen for the target distribution type. Instead, for each pair of predicted and oracle messages from factor \psi to variable i, we calculate the marginals b_i and b^{oracle}_i that they each induce on the target random variable, and compute the Kullback-Leibler (KL) divergence between the two:

    D^{mar}_{KL}(m_{\psi \to i} || m^{oracle}_{\psi \to i}) \equiv D_{KL}(b_i || b^{oracle}_i),    (8)

where b_i = m_{\psi \to i} \cdot m_{i \to \psi} and b^{oracle}_i = m^{oracle}_{\psi \to i} \cdot m_{i \to \psi}, using the fact that beliefs can be computed as the product of incoming and outgoing messages on any edge. We refer to the error measure D^{mar}_{KL} as marginal KL and use it throughout the JIT framework, as it encourages the system to focus efforts on the quantity that is ultimately of interest: the accuracy of the posterior marginals.

4 Random decision forests for JIT learning

We wish to learn a mapping from a set of incoming messages {m_{k \to \psi}}_{k \in ne(\psi)} to the outgoing message m_{\psi \to i}. Note that separate regressors are trained for each outgoing message. 
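For univariate Gaussian messages, the marginal KL of Eq. 8 has a closed form. A minimal sketch, assuming messages are represented as (mean, variance) pairs rather than the natural-parameter representation a full implementation would use:

```python
import numpy as np

def gaussian_product(m1, m2):
    """Multiply two Gaussian messages given as (mean, variance) pairs
    by adding their natural parameters (precision, precision * mean)."""
    prec = 1.0 / m1[1] + 1.0 / m2[1]
    prec_mean = m1[0] / m1[1] + m2[0] / m2[1]
    return prec_mean / prec, 1.0 / prec

def kl_gaussian(p, q):
    """Closed-form KL(N(mu0, v0) || N(mu1, v1))."""
    (mu0, v0), (mu1, v1) = p, q
    return 0.5 * (v0 / v1 + (mu1 - mu0) ** 2 / v1 - 1.0 + np.log(v1 / v0))

def marginal_kl(m_pred, m_oracle, m_incoming):
    """Eq. 8: compare the predicted and oracle messages through the
    beliefs (marginals) they induce on the target variable."""
    b_pred = gaussian_product(m_pred, m_incoming)
    b_oracle = gaussian_product(m_oracle, m_incoming)
    return kl_gaussian(b_pred, b_oracle)
```

Measuring error on the induced beliefs, rather than on raw message parameters, is what makes the criterion insensitive to the parameterization of the message family.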
We require\nthat the regressor: 1) trains and predicts ef\ufb01ciently, 2) can model arbitrarily complex mappings,\n3) can adapt dynamically, and 4) produces uncertainty estimates. Here we describe how decision\nforests can be modi\ufb01ed to satisfy these requirements. For a review of decision forests see [4].\nIn EP, each incoming and outgoing message can be represented using only a few numbers, e.g. a\nGaussian message can be represented by its natural parameters. We refer to the outgoing message by\nmout and to the set of incoming messages by min. Each set of incoming messages min is represented\nin two ways: the \ufb01rst, a concatenation of the parameters of its constituent messages which we call the\n\u2018regression parameterization\u2019 and denote by rin; and the second, a vector of features computed on the\nset which we call the \u2018tree parameterization\u2019 and denote by tin. This tree parametrization typically\ncontains values for a larger number of properties of each constituent message (e.g. parameters and\nmoments), and also properties of the set as a whole (e.g. \u03c8 evaluated at the mode of min). We\nrepresent the outgoing message mout by a vector of real valued numbers rout. Note that din and dout,\nthe number of elements in rin and rout respectively, need not be equal.\nWeak learner model. Data arriving at a split node j is separated into the node\u2019s two children\naccording to a binary weak learner h(tin, \u03c4 j) \u2208 {0, 1}, where \u03c4 j parameterizes the split criterion.\nWe use weak learners of the generic oriented hyperplane type throughout (see [4] for details).\nPrediction model. Each leaf node is associated with a subset of the labelled training data. During\ntesting, a previously unseen set of incoming messages traverses the tree until it reaches a leaf which\nby construction is likely to contain similar training examples. 
We therefore use the statistics of the data gathered in that leaf to predict outgoing messages with a multivariate polynomial regression model of the form: r^{train}_out = W \cdot \phi_n(r^{train}_in) + \epsilon, where \phi_n(.) is the n-th degree polynomial basis function, and \epsilon is the d_out-dimensional vector of normal error terms. We use the learned d_out x d_in-dimensional matrix of coefficients W at test time to make predictions r_out for each r_in. To recap, t_in is used to traverse message sets down to leaves, and r_in is used by the linear regressor to predict r_out.
Training objective function. The optimization of the split functions proceeds in a greedy manner. At each node j, depending on the subset of the incoming training set S_j, we learn the function that 'best' splits S_j into the training sets corresponding to each child, S^L_j and S^R_j, i.e. \tau_j = argmax_{\tau \in T_j} I(S_j, \tau). This optimization is performed as a search over a discrete set T_j of a random sample of possible parameter settings. The number of elements in T_j is typically kept small, introducing random variation in the different trees in the forest. The objective function I is:

    I(S_j, \tau) = -E(S^L_j, W^L) - E(S^R_j, W^R),    (9)

where W^L and W^R are the parameters of the polynomial regression models corresponding to the left and right training sets S^L_j and S^R_j, and the 'fit residual' E is:

    E(S, W) = (1/2) \sum_{m_in \in S} [ D^{mar}_{KL}(m^W_{m_in} || m^{oracle}_{m_in}) + D^{mar}_{KL}(m^{oracle}_{m_in} || m^W_{m_in}) ].    (10)

Here m_in is a set of incoming messages in S, m^{oracle}_{m_in} is the oracle outgoing message, m^W_{m_in} is the estimate produced by the regression model specified by W, and D^{mar}_{KL} is the marginal KL. 
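The leaf prediction model r_out = W \cdot \phi_n(r_in) can be fit by ordinary least squares on the training pairs stored at a leaf. A sketch under that assumption (function names are illustrative, not from the paper's implementation):

```python
import numpy as np

def poly_features(r_in, degree=2):
    """phi_n(.): polynomial basis of the regression parameterization
    r_in up to the given degree, with a constant term."""
    return np.concatenate([[1.0]] + [r_in ** d for d in range(1, degree + 1)])

def fit_leaf_regressor(R_in, R_out, degree=2):
    """Fit the coefficient matrix W at a leaf by least squares on the
    (r_in, r_out) training pairs gathered at that leaf."""
    Phi = np.stack([poly_features(r, degree) for r in R_in])        # N x p
    W, *_ = np.linalg.lstsq(Phi, np.asarray(R_out), rcond=None)    # p x d_out
    return W.T  # d_out x p: rows map basis features to components of r_out

def predict_leaf(W, r_in, degree=2):
    """r_out = W . phi_n(r_in) for a test point reaching this leaf."""
    return W @ poly_features(r_in, degree)
```

A full implementation would instead select W to minimize the symmetrized marginal KL of Eq. 10; least squares on the message parameters is a simpler proxy used here only to make the structure concrete.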
In simple terms, this objective function splits the training data at each node in a way that the relationship between the incoming and outgoing messages is well captured by the polynomial regression in each child, as measured by symmetrized marginal KL.

Ensemble model. A key aspect of forests is that their trees are randomly different from each other. This is due to the relatively small number of weak learner candidates considered in the optimization of the weak learners. During testing, each test point m_in simultaneously traverses all trees from their roots until it reaches their leaves. Combining the predictions into a single forest prediction may be done by averaging the parameters r^t_out of the predicted outgoing messages m^t_out by each tree t; however, again this would be sensitive to the parameterizations of the output distribution types. Instead, we compute the moment average m_out of the distributions {m^t_out} by averaging the first few moments of each predicted distribution across trees, and solving for the distribution parameters which match the averaged moments. Grosse et al. [6] study the characteristics of the moment average in detail, and have shown that it can be interpreted as minimizing an objective function m_out = argmin_m U({m^t_out}, m), where U({m^t_out}, m) = \sum_t D_{KL}(m^t_out || m).
Intuitively, the level of agreement between the predictions of the different trees can be used as a proxy of the forest's uncertainty about that prediction (we choose not to use uncertainty within leaves in order to maintain high prediction speed). If all the trees in the forest predict the same output distribution, it means that their knowledge about the function f is similar despite the randomness in their structures. We therefore set u_out \equiv U({m^t_out}, m_out). 
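For univariate Gaussian predictions, the moment average and the disagreement-based uncertainty u_out = U({m^t_out}, m_out) can be sketched as follows; this is a simplification that assumes each tree's prediction is a (mean, variance) pair, and moment matching for other exponential-family messages differs:

```python
import numpy as np

def moment_average(preds):
    """Combine per-tree Gaussian predictions, given as (mean, variance)
    pairs, by averaging their first two moments and solving for the
    Gaussian whose moments match the averages."""
    means = np.array([m for m, _ in preds])
    seconds = np.array([v + m ** 2 for m, v in preds])  # E[x^2] per tree
    mu = means.mean()
    return mu, seconds.mean() - mu ** 2

def forest_uncertainty(preds, m_avg):
    """u_out = U({m_t}, m_avg) = sum_t KL(m_t || m_avg): total divergence
    of the trees' predictions from the moment average, so disagreement
    between trees translates into high uncertainty."""
    mu1, v1 = m_avg
    return sum(0.5 * (v0 / v1 + (mu1 - mu0) ** 2 / v1 - 1.0 + np.log(v1 / v0))
               for mu0, v0 in preds)
```

Note that when trees disagree the averaged second moment inflates the variance of m_out, so the combined prediction is automatically broader than any individual tree's.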
A similar notion is used for classification forests, where the entropy of the aggregate output histogram is used as a proxy of the classification's uncertainty [4]. We illustrate how this idea extends to simple regression forests in Fig. 1b, and in Sec. 6 we also show empirically that this uncertainty measure works well in practice.
Online training. During learning, the trees periodically obtain new information in the form of (m_in, m^{oracle}_out) pairs. The forest makes use of this by pushing m_in down a portion 0 < \rho <= 1 of the trees to their leaf nodes and retraining the regressors at those leaves. Typically \rho = 1; however, we use values smaller than 1 when the trees are shallow (due to the mapping function being captured well by the regressors at the leaves) and the forest's randomness is too low to produce reliable uncertainty estimates. If the regressor's fit residual E at a leaf (Eq. 10) is above a user-specified threshold value E^{max}_{leaf}, a split is triggered on that node. Note that no depth limit is ever specified.

5 Related work

There are a number of works in the literature that consider using regressors to speed up general-purpose inference algorithms. For example, the Inverse MCMC algorithm [20] uses discriminative estimates of local conditional distributions to make proposals for a Metropolis-Hastings sampler; however, these predictors are not aware of their own uncertainty. Therefore the decision of when the sampler can start to rely on them needs to be made manually, and the user has to explicitly separate offline training from test-time inference computations.
A related line of work is that of inference machines [14, 15, 17, 13]. 
Here, message-passing is performed by a sequence of predictions, where the sequence itself is defined by the graphical model. The predictors are jointly trained to ensure that the system produces correct labellings; however, the resulting inference procedure no longer corresponds to the original (or perhaps to any) graphical model and therefore the method is unsuitable if we care about querying the model's latent variables.
The closest work to ours is [7], in which Heess et al. use neural networks to learn to pass EP messages. However, their method requires the user to anticipate the set of messages that will ever be sent by the factor ahead of time (itself a highly non-trivial task), and it has no notion of confidence in its predictions, so it will silently fail when it sees unfamiliar input messages. In contrast, the JIT learner trains in the context of a specific model, thereby allocating resources more efficiently, and because it knows what it knows, it buys generality without having to do extensive pre-training.

6 Experiments

We first analyze the behaviour of JIT learning with diagnostic experiments on two factors: logistic and compound gamma, which were also considered by [7]. We then demonstrate its application to a challenging model of US corn yield data. The experiments were performed using the extensible factor API in Infer.NET [12]. Unless stated otherwise, we use default Infer.NET settings (e.g. for message schedules and other factor implementations). We set the number of trees in each forest to 64 and use quadratic regressors. Message parameterizations and graphical models, experiments on a product factor and a quantitative comparison with [7] can be found in the supplementary material.

(a) Inference error (b) Worst predicted messages (c) Awareness of uncertainty

Figure 2: Uncertainty aware regression. 
All plots for the Gaussian forest.\n(a) Histogram of\nmarginal KLs of outgoing messages, which are typically very small. (b) The forest\u2019s most inaccurate\npredictions (black: moracle, red: m, dashed black: boracle, purple: b). (c) The regressor\u2019s uncertainty\nincreases in tandem with marginal KL, i.e. it does not make con\ufb01dent but inaccurate predictions.\n\n(a) Oracle consultation rate\n\n(b) Inference time\n\n(c) Inference error\n\nFigure 3: Logistic JIT learning. (a) The factor consults the oracle for only a fraction of messages,\n(b) leading to signi\ufb01cant savings in time, (c) whilst maintaining (or even decreasing) inference error.\n\nLogistic. We have access to a hand-crafted EP implementation of this factor, allowing us to perform\nquantitative analysis of the JIT framework\u2019s performance. The logistic deterministically computes\nxout = \u03c3(xin) = 1/(1+exp{\u2212xin}). Sensible choices for the incoming and outgoing message types\nare Gaussian and Beta respectively. We study the logistic factor in the context of Bayesian logistic\nregression models, where the relationship between an input vector x and a binary output observation\ny is modeled as p(y = 1) = \u03c3(wT x). We place zero-mean, unit-variance Gaussian priors on the\nentries of regression parameters w, and run EP inference for 10 iterations.\nWe \ufb01rst demonstrate that the forests described in Sec. 4 are fast and accurate uncertainty aware\nregressors by applying them to \ufb01ve synthetic logistic regression \u2018problems\u2019 as follows: for each\nproblem, we sample a groundtruth w and training xs from N (0, 1) and then sample their corre-\nsponding ys. We use a Bayesian logistic regression model to infer ws using the training datasets\nand make predictions on the test datasets, whilst recording the messages that the factor receives and\nsends during both kinds of inference. 
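The synthetic problem generation described above can be sketched as follows; the function name and the default sizes are illustrative, not taken from the paper:

```python
import numpy as np

def make_logistic_problem(n_train=100, n_test=50, dim=5, rng=None):
    """One synthetic logistic regression 'problem': groundtruth w and
    inputs drawn from N(0, 1), labels drawn from Bernoulli(sigma(w.x))."""
    rng = rng or np.random.default_rng()
    w = rng.normal(size=dim)
    X = rng.normal(size=(n_train + n_test, dim))
    p = 1.0 / (1.0 + np.exp(-X @ w))
    y = rng.random(n_train + n_test) < p
    return w, (X[:n_train], y[:n_train]), (X[n_train:], y[n_train:])
```

Keeping w fixed while regenerating (x, y) sets, as in the JIT experiments below, only requires reusing the returned w across calls.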
We split the observed message sets into training (70%) and hold out (30%), and train and evaluate the random forests using the two datasets. In Fig. 2 we show that the regressor is accurate and that it is uncertain whenever it makes predictions with higher error.
One useful diagnostic for choosing the various parameters of the forests (including the choice of parameterization for r_in and t_in, as well as the leaf tolerance E^{max}_{leaf}) is the average utilization of its leaves during held out prediction, i.e. what fraction of leaves are visited at test time. In this experiment the forests obtain an average utilization of 1, meaning that every leaf contributes to the predictions of the 30% held out data, thereby indicating that the forests have learned a highly compact representation of the underlying function. As described in Sec. 3, we also use the data gathered in this experiment to find an appropriate value of u_max for use in just-in-time learning.
Next we evaluate the uncertainty aware regressor in the context of JIT learning. We present several related regression problems to a JIT logistic factor, i.e. we keep w fixed and generate multiple new {(x, y)} sets. This is a natural setting since often in practice we observe multiple datasets which we believe to have been generated by the same underlying process. For each problem, using the JIT factor we infer the regression weights and make predictions on test inputs, comparing wall-clock time and accuracy with non-JIT implementations of the factor. 
We consider two kinds of oracles: those that consult Infer.NET's message operators and those that use importance sampling (Eq. 4).
As a baseline, we also implemented a K-nearest neighbour (KNN) uncertainty aware regressor. Here, messages are represented using their natural parameters, the uncertainty associated with each prediction is the mean distance from the K-closest points in this space, and the outgoing message's parameters are found by taking the average of the parameters of the K-closest output messages. We use the same procedure as the one described in Sec. 3 to choose u_max for KNN.
We observe that the JIT factor does indeed learn about the inference problem over time. Fig. 3a shows that the rate at which the factor consults the oracle decreases over the course of the experiment, reaching zero at times (i.e. for these problems the factor relies entirely on its predictions). 
On\naverage, the factor sends 97.7% of its messages without consulting the sampling oracle (a higher rate\nof 99.2% when using Infer.NET as the oracle, due to lack of sampling noise), which leads to several\norders of magnitude savings in inference time (from around 8 minutes for sampling to around 800\nms for sampling + JIT), even increasing the speed of our Infer.NET implementation (from around\n1300 ms to around 800 ms on average, Fig. 3b). Note that the forests are not merely memorising a\nmapping from input to output messages, as evidenced by the difference in the consultation rates of\nJIT and KNN, and that KNN speed deteriorates as the database grows. Surprisingly, we observe that\nthe JIT regressors in fact decrease the KL between the results produced by importance sampling and\nInfer.NET, thereby increasing overall inference accuracy (Fig. 3c, this could be due to the fact that\nthe regressors at the leaves of the forests smooth out the noise of the sampled messages). Reducing\nthe number of importance samples to reach speed parity with JIT drastically degrades the accuracy\nof the outgoing messages, increasing overall log KL error from around \u221211 to around \u22124.\nCompound gamma. The second factor we investigate is the compound gamma factor. The com-\npound gamma construction is used as a heavy-tailed prior over precisions of Gaussian random vari-\nables, where \ufb01rst r2 is drawn from a gamma with rate r1 and shape s1 and the precision of the\nGaussian is set to be a draw from a gamma with rate r2 and shape s2. 
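The forward sampler that the JIT framework needs from the collapsed compound gamma factor amounts to a few lines. A sketch (the function name and signature are illustrative; note that NumPy parameterizes the gamma by scale = 1/rate):

```python
import numpy as np

def sample_compound_gamma_precision(s1, r1, s2, rng=None):
    """Forward sampler for the compound gamma construction: draw
    r2 ~ Gamma(shape=s1, rate=r1), then draw the Gaussian precision
    tau ~ Gamma(shape=s2, rate=r2)."""
    rng = rng or np.random.default_rng()
    r2 = rng.gamma(shape=s1, scale=1.0 / r1)
    return rng.gamma(shape=s2, scale=1.0 / r2)
```

Providing this sampler is all that is needed to treat the two-stage construction as a single factor; the importance-sampling oracle and the JIT regressor handle the rest.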
Here, we have access to\nclosed-form implementations of the two gamma factors in the construction, however we use the JIT\nframework to collapse the two into a single factor for increased speed.\nWe study the compound gamma factor in the context of Gaussian \ufb01tting, where we sample a ran-\ndom number of points from multiple Gaussians with a wide range of precisions, and then infer the\nprecision of the generating Gaussians via Bayesian inference using a compound gamma prior. The\nnumber of samples varies between 10 and 100 and the precision varies between 10\u22124 and 104 in\neach problem. The compound factor learns the message mapping after around 20 problems (see\nFig. 4a). Note that only a single message is sent by the factor in each episode, hence the abrupt drop\nin inference time. This increase in performance comes at negligible loss of accuracy (Figs. 4b, 4c).\nYield. We also consider a more realistic application to scienti\ufb01c modelling. This is an example\nof a scenario for which our framework is particularly suited: scientists often need to build large\nmodels with factors that directly take knowledge about certain components of the problem into\naccount. We use JIT learning to implement a factor that relates agriculture yields to temperature in\nthe context of an ecological climate model. Ecologists have strong empirical beliefs about the form\nof the relationship between temperature and yield (that yield increases gradually up to some optimal\ntemperature but drops sharply after that point; see Fig 5a and [16, 10]) and it is imperative that this\nrelationship is modelled faithfully. Deriving closed form message-operators is a non-trivial task, and\ntherefore current state-of-the-art is sampling-based (e.g. [3]) and highly computationally intensive.\n\n(a) Inference time\n\n(b) Inference error\n\n(c) Accuracy (1 dot per problem)\n\nFigure 4: Compound gamma JIT learning. 
(a) JIT reduces inference time for sampling from ~11 seconds to ~1 ms. (b) JIT's posteriors agree highly with Infer.NET's. Using fewer samples to match JIT speed leads to a degradation of accuracy. (c) Increased speed comes at negligible loss of accuracy.

(a) The yield factor. (b) Oracle consultation rate (2011, 2012, 2013). (c) Accuracy (1 dot per county).

Figure 5: A probabilistic model of corn yield. (a) Ecologists believe that yield increases gradually up to some optimal temperature but drops sharply after that point [16, 10], and they wish to incorporate this knowledge into their models faithfully. (b) Average consultation rate per 1,000 messages over the course of inference on the three datasets. Notice the decrease within and across datasets. (c) Significant savings in inference time (Table 1) come at a small cost in inference accuracy.

We obtain yield data for 10% of US counties for 2011-2013 from the USDA National Agricultural Statistics Service [1] and corresponding temperature data using [18]. We first demonstrate that it is possible to perform inference in a large-scale ecological model of this kind with EP (graphical model shown in Fig.
1c; derived in collaboration with computational ecologists; see the supplementary material for a description), using importance sampling to compute messages for the yield factor, for which we lack message-passing operators. In addition to the difficulty of computing messages for the multidimensional yield factor, inference in the model is challenging as it includes multiple Gaussian processes, separate t_opt and y_max variables for each location, and many copies of the yield factor, and its graph is loopy. Results of inference are shown in the supplementary material.
We find that with around 100,000 samples the message for the yield factor can be computed accurately, making these by far the slowest computations in the inference procedure. We apply JIT learning by regressing these messages instead. The high arity of the factor makes the task particularly challenging, as it increases the complexity of the mapping function being learned. Despite this, we find that when performing inference on the 2011 data the factor can learn to accurately send up to 54% of messages without having to consult the oracle, resulting in a speedup of 195%.
A common scenario is one in which we collect more data and wish to repeat inference. We use the forests learned at the end of inference on the 2011 data to perform inference on the 2012 data, and the forests learned at the end of this to do inference on the 2013 data, and compare to JIT learning from scratch for each dataset. The factor transfers its knowledge across the problems, increasing inference speedup from 195% to 289% and 317% in the latter two experiments respectively (Table 1), whilst maintaining overall inference accuracy (Fig. 5c).
7 Discussion
The success of JIT learning depends heavily on the accuracy of the regressor and its knowledge of its own uncertainty. Random forests have proven adequate; however, alternatives may exist, and a more sophisticated estimate of uncertainty (e.g.
using Gaussian processes) is likely to lead to an increased rate of learning. A second critical ingredient is an appropriate choice of u_max, which currently requires a certain amount of manual tuning.
In this paper we showed that it is possible to speed up inference by combining EP, importance sampling and JIT learning; however, it will be of interest to study other inference settings where JIT ideas might be applicable. Surprisingly, our experiments also showed that JIT learning can increase the accuracy of sampling and accelerate hand-coded message operators, suggesting that it will be fruitful to use JIT to remove bottlenecks even in existing, optimized inference code.

Table 1: FR is the fraction of regressions with no oracle consultation.

             IS      JIT fresh       JIT continued
     Year    Time    FR   Speedup    FR   Speedup
     2011    451s    54%  195%       —    —
     2012    449s    54%  192%       60%  288%
     2013    451s    54%  191%       64%  318%

Acknowledgments

Thanks to Tom Minka and Alex Spengler for valuable discussions, and to Silvia Caldararu and Drew Purves for introducing us to the corn yield datasets and models.

References

[1] National Agricultural Statistics Service, 2013. United States Department of Agriculture. http://quickstats.nass.usda.gov/.

[2] Simon Barthelmé and Nicolas Chopin. ABC-EP: Expectation Propagation for Likelihood-free Bayesian Computation. In Proceedings of the 28th International Conference on Machine Learning, pages 289-296, 2011.

[3] Silvia Caldararu, Vassily Lyutsarev, Christopher McEwan, and Drew Purves. Filzbach, 2013. Microsoft Research Cambridge. Website URL: http://research.microsoft.com/en-us/projects/filzbach/.

[4] Antonio Criminisi and Jamie Shotton. Decision Forests for Computer Vision and Medical Image Analysis. Springer Publishing Company, Incorporated, 2013.

[5] Noah D. Goodman, Vikash K. Mansinghka, Daniel Roy, Keith Bonawitz, and Joshua B. Tenenbaum. Church: a language for generative models. In Uncertainty in Artificial Intelligence, 2008.

[6] Roger B. Grosse, Chris J. Maddison, and Ruslan Salakhutdinov. Annealing between distributions by averaging moments. In Advances in Neural Information Processing Systems 26, pages 2769-2777, 2013.

[7] Nicolas Heess, Daniel Tarlow, and John Winn. Learning to Pass Expectation Propagation Messages. In Advances in Neural Information Processing Systems 26, pages 3219-3227, 2013.

[8] David D. Lewis and William A. Gale. A Sequential Algorithm for Training Text Classifiers. In Special Interest Group on Information Retrieval, pages 3-12. Springer London, 1994.

[9] Lihong Li, Michael L. Littman, and Thomas J. Walsh. Knows what it knows: a framework for self-aware learning. In Proceedings of the 25th International Conference on Machine Learning, pages 568-575, New York, NY, USA, 2008. ACM.

[10] David B. Lobell, Marianne Banziger, Cosmos Magorokosho, and Bindiganavile Vivek. Nonlinear heat effects on African maize as evidenced by historical yield trials. Nature Climate Change, 1:42-45, 2011.

[11] Thomas Minka. Expectation Propagation for approximate Bayesian inference. PhD thesis, Massachusetts Institute of Technology, 2001.

[12] Thomas Minka, John Winn, John Guiver, and David Knowles. Infer.NET 2.5, 2012. Microsoft Research Cambridge. Website URL: http://research.microsoft.com/infernet.

[13] Daniel Munoz. Inference Machines: Parsing Scenes via Iterated Predictions. PhD thesis, The Robotics Institute, Carnegie Mellon University, June 2013.

[14] Daniel Munoz, J.
Andrew Bagnell, and Martial Hebert. Stacked Hierarchical Labeling. In European Conference on Computer Vision, 2010.

[15] Stephane Ross, Daniel Munoz, Martial Hebert, and J. Andrew Bagnell. Learning Message-Passing Inference Machines for Structured Prediction. In Conference on Computer Vision and Pattern Recognition, 2011.

[16] Wolfram Schlenker and Michael J. Roberts. Nonlinear temperature effects indicate severe damages to U.S. crop yields under climate change. Proceedings of the National Academy of Sciences, 106(37):15594-15598, 2009.

[17] Roman Shapovalov, Dmitry Vetrov, and Pushmeet Kohli. Spatial Inference Machines. In Conference on Computer Vision and Pattern Recognition, pages 2985-2992, 2013.

[18] Matthew J. Smith, Paul I. Palmer, Drew W. Purves, Mark C. Vanderwel, Vassily Lyutsarev, Ben Calderhead, Lucas N. Joppa, Christopher M. Bishop, and Stephen Emmott. Changing how Earth System Modelling is done to provide more useful information for decision making, science and society. Bulletin of the American Meteorological Society, 2014.

[19] Stan Development Team. Stan: A C++ Library for Probability and Sampling, 2014.

[20] Andreas Stuhlmüller, Jessica Taylor, and Noah D. Goodman. Learning Stochastic Inverses. In Advances in Neural Information Processing Systems 27, 2013.