{"title": "Memory-Efficient Backpropagation Through Time", "book": "Advances in Neural Information Processing Systems", "page_first": 4125, "page_last": 4133, "abstract": "We propose a novel approach to reduce memory consumption of the backpropagation through time (BPTT) algorithm when training recurrent neural networks (RNNs). Our approach uses dynamic programming to balance a trade-off between caching of intermediate results and recomputation. The algorithm is capable of tightly fitting within almost any user-set memory budget while finding an optimal execution policy minimizing the computational cost. Computational devices have limited memory capacity and maximizing a computational performance given a fixed memory budget is a practical use-case. We provide asymptotic computational upper bounds for various regimes. The algorithm is particularly effective for long sequences. For sequences of length 1000, our algorithm saves 95\\% of memory usage while using only one third more time per iteration than the standard BPTT.", "full_text": "Memory-Ef\ufb01cient Backpropagation Through Time\n\nAudr\u00afunas Gruslys\nGoogle DeepMind\n\naudrunas@google.com\n\nR\u00e9mi Munos\n\nGoogle DeepMind\n\nmunos@google.com\n\nIvo Danihelka\n\nGoogle DeepMind\n\ndanihelka@google.com\n\nMarc Lanctot\n\nGoogle DeepMind\n\nlanctot@google.com\n\nAlex Graves\n\nGoogle DeepMind\n\ngravesa@google.com\n\nAbstract\n\nWe propose a novel approach to reduce memory consumption of the backpropa-\ngation through time (BPTT) algorithm when training recurrent neural networks\n(RNNs). Our approach uses dynamic programming to balance a trade-off between\ncaching of intermediate results and recomputation. The algorithm is capable of\ntightly \ufb01tting within almost any user-set memory budget while \ufb01nding an optimal\nexecution policy minimizing the computational cost. 
Computational devices have limited memory capacity, and maximizing computational performance given a fixed memory budget is a practical use case. We provide asymptotic computational upper bounds for various regimes. The algorithm is particularly effective for long sequences. For sequences of length 1000, our algorithm saves 95% of memory usage while using only one third more time per iteration than standard BPTT.\n\n1 Introduction\n\nRecurrent neural networks (RNNs) are artificial neural networks where connections between units can form cycles. They are often used for sequence mapping problems, as they can propagate hidden state information from early parts of the sequence back to later points. LSTM [9] in particular is an RNN architecture that has excelled in sequence generation [3, 13, 4], speech recognition [5] and reinforcement learning [12, 10] settings. Other successful RNN architectures include the differentiable neural computer (DNC) [6], the DRAW network [8], and Neural Transducers [7].\nThe Backpropagation Through Time (BPTT) algorithm [11, 14] is typically used to obtain gradients during training. One important problem is the large memory consumption required by BPTT. This is especially troublesome when using Graphics Processing Units (GPUs) due to the limitations of GPU memory.\nThe memory budget is typically known in advance. Our algorithm balances the trade-off between memorization and recomputation by finding an optimal memory usage policy which minimizes the total computational cost for any fixed memory budget. The algorithm exploits the fact that the same memory slots may be reused multiple times. The idea of using dynamic programming to find a provably optimal policy is the main contribution of this paper.\nOur approach is largely architecture-agnostic and works with most recurrent neural networks. 
Being able to fit within limited-memory devices such as GPUs will typically compensate for any increase in computational cost.\n\n2 Background and related work\n\nIn this section, we describe the key terms and relevant previous work for memory-saving in RNNs.\n\n30th Conference on Neural Information Processing Systems (NIPS 2016), Barcelona, Spain.\n\nDefinition 1. An RNN core is a feed-forward neural network which is cloned (unfolded in time) repeatedly, where each clone represents a particular time point in the recurrence.\n\nFor example, if an RNN has a single hidden layer whose outputs feed back into the same hidden layer, then for a sequence length of t the unfolded network is feed-forward and contains t RNN cores.\n\nDefinition 2. The hidden state of the recurrent network is the part of the output of the RNN core which is passed into the next RNN core as an input.\n\nIn addition to the initial hidden state, there exists a single hidden state per time step once the network is unfolded.\n\nDefinition 3. The internal state of the RNN core for a given time point is all the necessary information required to backpropagate gradients over that time step once an input vector, a gradient with respect to the output vector, and a gradient with respect to the output hidden state are supplied. We define it to also include an output hidden state.\n\nAn internal state can be (re)evaluated by executing a single forward operation taking the previous hidden state and the respective entry of an input sequence as an input. For most network architectures, the internal state of the RNN core will include a hidden input state, as this is normally required to evaluate gradients. This particular choice of definition will be useful later in the paper.\n\nDefinition 4. 
A memory slot is a unit of memory which is capable of storing a single hidden state or a single internal state (depending on the context).\n\n2.1 Backpropagation through Time\n\nBackpropagation through Time (BPTT) [11, 14] is one of the commonly used techniques to train recurrent networks. BPTT \u201cunfolds\u201d the neural network in time by creating several copies of the recurrent units which can then be treated like a (deep) feed-forward network with tied weights. Once this is done, a standard forward-propagation technique can be used to evaluate network fitness over the whole sequence of inputs, while a standard backpropagation algorithm can be used to evaluate partial derivatives of the loss criterion with respect to all network parameters. This approach, while being computationally efficient, is also fairly intensive in memory usage. This is because the standard version of the algorithm effectively requires storing internal states of the unfolded network core at every time step in order to be able to evaluate correct partial derivatives.\n\n2.2 Trading memory for computation time\n\nThe general idea of trading computation time for memory consumption in general computation graphs has been investigated in the automatic differentiation community [2]. Recently, the rise of deep architectures and recurrent networks has increased interest in a less general case where the graph of forward computation is a chain and gradients have to be chained in reverse order. This simplification leads to relatively simple memory-saving strategies and heuristics. In the context of BPTT, instead of storing hidden network states, some of the intermediate results can be recomputed on demand by executing an extra forward operation.\nChen et al. proposed subdividing the sequence of size t into \u221at equal parts and memorizing only hidden states between the subsequences and all internal states within each segment [1]. 
This uses O(\u221at) memory at the cost of making an additional forward pass on average, as once the errors are backpropagated through the right side of the sequence, the second-last subsequence has to be restored by repeating a number of forward operations. We refer to this as Chen\u2019s \u221at algorithm.\nThe authors also suggest applying the same technique recursively several times by subdividing the sequence into k equal parts and terminating the recursion once the subsequence length becomes less than k. The authors have established that this would lead to memory consumption of O(k log_{k+1}(t)) and computational complexity of O(t log_k(t)). This algorithm has a minimum possible memory usage of log_2(t) in the case when k = 1. We refer to this as Chen\u2019s recursive algorithm.\n\n3 Memory-efficient backpropagation through time\n\nWe first discuss two simple examples: when memory is very scarce, and when it is somewhat limited.\nWhen memory is very scarce, it is straightforward to design a simple but computationally inefficient algorithm for backpropagation of errors on RNNs which only uses a constant amount of memory. Every time the state of the network at time t has to be restored, the algorithm simply re-evaluates the state by forward-propagating inputs starting from the beginning until time t. As backpropagation happens in the reverse temporal order, results from the previous forward steps cannot be reused (as there is no memory to store them). This requires repeating t forward steps before backpropagating gradients one step backwards (we only remember the inputs and the initial state), and produces an algorithm requiring t(t + 1)/2 forward passes to backpropagate errors over t time steps. The algorithm is O(1) in space and O(t^2) in time.\nWhen the memory is somewhat limited (but not very scarce) we may store only hidden RNN states at all time points. 
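As a quick sanity check on the constant-memory strategy above, the quadratic forward-operation count can be reproduced in a few lines. This is an illustrative sketch (the function name is ours, not from the paper):

```python
def constant_memory_bptt_cost(t):
    """Forward operations needed to backpropagate over t steps when only
    the inputs and the initial hidden state are kept in memory."""
    ops = 0
    for i in range(t, 0, -1):  # backpropagate in reverse temporal order
        ops += i               # re-run the forward pass from step 1 to step i
    return ops                 # equals t(t + 1) / 2
```

Storing hidden states at every step instead, as described next, replaces this quadratic recomputation with a single extra forward operation per step.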
When errors have to be backpropagated from time t to t \u2212 1, an internal RNN core state can be re-evaluated by executing another forward operation taking the previous hidden state as an input. The backward operation can follow immediately. This approach can lead to fairly significant memory savings, as typically the recurrent network hidden state is much smaller than an internal state of the network core itself. On the other hand, this leads to another forward operation being executed during the backpropagation stage.\n\n3.1 Backpropagation through time with selective hidden state memorization (BPTT-HSM)\n\nThe idea behind the proposed algorithm is to compromise between the two previous extremes. Suppose that we want to forward- and backpropagate a sequence of length t, but we are only able to store m hidden states in memory at any given time. We may reuse the same memory slots to store different hidden states during backpropagation. Also, suppose that we have a single RNN core available for the purposes of intermediate calculations which is able to store a single internal state. Define C(t, m) as the computational cost of backpropagation, measured in terms of how many forward operations one has to make in total during the forward and backpropagation steps combined when following an optimal memory usage policy minimizing the computational cost. One can easily set the boundary conditions: C(t, 1) = t(t + 1)/2 is the cost of the minimal memory approach, while C(t, m) = 2t \u2212 1 for all m \u2265 t when memory is plentiful (as shown in Fig. 3 a). Our approach is illustrated in Figure 1. Once we start forward-propagating steps at time t = t_0, at any given point y > t_0 we can choose to put the current hidden state into memory (step 1). This step has the cost of y forward operations. States will be read in the reverse order in which they were written: this allows the algorithm to store states in a stack. 
Once the state is put into memory at time y = D(t, m), we can reduce the problem into two parts by using a divide-and-conquer approach: running the same algorithm on the t > y side of the sequence while using m \u2212 1 of the remaining memory slots at the cost of C(t \u2212 y, m \u2212 1) (step 2), and then reusing m memory slots when backpropagating on the t \u2264 y side at the cost of C(y, m) (step 3). We use the full size-m memory capacity when performing step 3 because we could release the hidden state y immediately after finishing step 2.\n\nFigure 1: The proposed divide-and-conquer approach.\n\nThe base case for the recurrent algorithm is simply a sequence of length t = 1, when forward and backward propagation may be done trivially on a single available RNN network core. This step has the cost C(1, m) = 1.\n\nFigure 2: Computational cost per time step when the algorithm is allowed to remember 10 (red), 50 (green), 100 (blue), 500 (violet), 1000 (cyan) hidden states: (a) theoretical computational cost measured in number of forward operations per time step; (b) measured computational cost in milliseconds. The grey line shows the performance of standard BPTT without memory constraints; (b) also includes a large constant value caused by a single backwards step per time step which was excluded from the theoretical computation, which makes the relative performance loss much less severe in practice than in theory.\n\nHaving established the protocol we may find an optimal policy D(t, m). 
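The divide-and-conquer recursion above (pay y forward operations to store a state at position y, then solve the two subproblems) can be written directly as a small dynamic program over (t, m). The sketch below is an illustrative reimplementation of that recursion and its boundary conditions, not the authors' code:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def C(t, m):
    """Minimal total number of forward operations needed to forward- and
    backpropagate a sequence of length t with m hidden-state memory slots."""
    if t == 1:
        return 1                  # base case: C(1, m) = 1
    if m == 1:
        return t * (t + 1) // 2   # constant-memory recomputation boundary
    # Q(t, m, y) = y + C(t - y, m - 1) + C(y, m); minimize over the split y
    return min(y + C(t - y, m - 1) + C(y, m) for y in range(1, t))

def D(t, m):
    """Optimal position y at which to store the first hidden state (t >= 2)."""
    return min(range(1, t), key=lambda y: y + C(t - y, m - 1) + C(y, m))
```

For plentiful memory (m >= t) this recovers C(t, m) = 2t - 1, and for a single slot the quadratic t(t + 1)/2 cost, matching the boundary conditions stated above.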
Define the cost of choosing the first state to be pushed at position y and later following the optimal policy as:\n\nQ(t, m, y) = y + C(t \u2212 y, m \u2212 1) + C(y, m)    (1)\n\nD(t, m) = argmin_{1 \u2264 y < t} Q(t, m, y),  C(t, m) = Q(t, m, D(t, m))    (2)\n\n\u2026 1_{y>1}\u03b1 \u2212 1_{y=1}\u03b2)    (10)\n\n1 is an indicator function. Equations for H(t, m), D_i(t, m) and C(t, m) are identical to (8) and (9).\n\n3.5 Analytical upper bound for BPTT-HSM\n\nWe have established a theoretical upper bound for the BPTT-HSM algorithm of C(t, m) \u2264 m t^{1+1/m}. As the bound is not tight for short sequences, it was also numerically verified that C(t, m) < 4 t^{1+1/m} for t < 10^5 and m < 10^3, or less than 3 t^{1+1/m} if the initial forward pass is excluded. In addition to that, we have established a different bound in the regime where t < m^m/m!: for any integer value a and for all t < m^a/a!, the computational cost is bounded by C(t, m) \u2264 (a + 1)t. The proofs are given in the supplementary material. Please refer to the supplementary material for discussion of the upper bounds for BPTT-MSM and BPTT-ISM.\n\n3.6 Comparison of the three different strategies\n\nFigure 6: Comparison of the three strategies in the case when the size of an internal RNN core state is \u03b1 = 5 times larger than that of the hidden state, and the total memory capacity allows us to remember either 10 internal RNN states, or 50 hidden states, or any arbitrary mixture of those in the left plot ((a) using 10\u03b1 memory), and (20, 100) respectively in the right plot ((b) using 20\u03b1 memory). The red curve illustrates BPTT-HSM, the green curve BPTT-ISM, and the blue curve BPTT-MSM. Please note that for large sequence lengths the red curve outperforms the green one, and the blue curve outperforms the other two.\n\nComputational costs for each previously described strategy are shown in Figure 6. BPTT-MSM outperforms both BPTT-ISM and BPTT-HSM. 
This is unsurprising, because the search space in that case is a superset of both strategy spaces, while the algorithm finds an optimal strategy within that space. Also, for a fixed memory capacity, the strategy memorizing only hidden states outperforms the strategy memorizing internal RNN core states for long sequences, while the latter outperforms the former for relatively short sequences.\n\n4 Discussion\n\nWe used an LSTM mapping 256 inputs to 256 with a batch size of 64 and measured execution time for a single gradient descent step (forward and backward operation combined) as a function of sequence length (Figure 2(b)). Please note that the measured computational time also includes the time taken by backward operations at each time step, which the dynamic programming equations did not take into account. A single backward operation is usually twice as expensive as a forward operation, because it involves evaluating gradients both with respect to input data and internal parameters. Still, as the number of backward operations is constant, it has no impact on the optimal strategy.\n\n4.1 Optimality\n\nThe dynamic program finds the optimal computational strategy by construction, subject to memory constraints and a fairly general model that we impose. 
As both strategies proposed by [1] are consistent with all the assumptions that we have made in section 3.4 when applied to RNNs, BPTT-MSM is guaranteed to perform at least as well under any memory budget and any sequence length. This is because the strategies proposed by [1] can be expressed by providing a (potentially suboptimal) policy D_i(t, m), H(t, m) subject to the same equations for Q_i(t, m).\n\n4.2 Numerical comparison with Chen\u2019s \u221at algorithm\n\nChen\u2019s \u221at algorithm requires remembering \u221at hidden states and \u221at internal RNN states (excluding input hidden states), while the recursive approach requires remembering at least log_2 t hidden states. In other words, the model does not allow for fine-grained control over memory usage and rather saves some memory. In the meantime our proposed BPTT-MSM can fit within almost arbitrary constant memory constraints, and this is the main advantage of our algorithm.\n\nFigure 7: Left: memory consumption divided by \u221at(1 + \u03b2) for a fixed computational cost C = 2. Right: computational cost per time step for a fixed memory consumption of \u221at(1 + \u03b2). Red, green and blue curves correspond to \u03b2 = 2, 5, 10 respectively.\n\nThe non-recursive Chen\u2019s \u221at approach does not allow matching any particular memory budget, making a like-for-like comparison difficult. Instead of fixing the memory budget, it is possible to fix the computational cost at 2 forward iterations on average to match the cost of the \u221at algorithm and observe how much memory our approach would use. Memory usage by the \u221at algorithm would be equivalent to saving \u221at hidden states and \u221at internal core states. Let us suppose that the internal RNN core state is \u03b1 times larger than the hidden state. In this case the size of the internal RNN core state excluding the input hidden state is \u03b2 = \u03b1 \u2212 1. This would give a memory usage of Chen\u2019s algorithm as \u221at(1 + \u03b2) = \u221at\u03b1, as it needs to remember \u221at hidden states and \u221at internal states where input hidden states can be omitted to avoid duplication. Figure 7 illustrates memory usage by our algorithm divided by \u221at(1 + \u03b2) for a fixed execution speed of 2 as a function of sequence length and for different values of the parameter \u03b2. Values lower than 1 indicate memory savings. As can be seen, we can save a significant amount of memory for the same computational cost.\nAnother experiment is to measure computational cost for a fixed memory consumption of \u221at(1 + \u03b2). The results are shown in Figure 7. A computational cost of 2 corresponds to Chen\u2019s \u221at algorithm. This illustrates that our approach does not perform significantly faster (although it does not do any worse). This is because Chen\u2019s \u221at strategy is actually near-optimal for this particular memory budget. Still, as seen from the previous paragraph, this memory budget is already in the regime of diminishing returns and further memory reductions are possible for almost the same computational cost.\n\n5 Conclusion\n\nIn this paper, we proposed a novel approach for finding optimal backpropagation strategies for recurrent neural networks for a fixed user-defined memory budget. We have demonstrated that the most general of the algorithms is at least as good as many other commonly used heuristics. The main advantage of our approach is the ability to tightly fit almost any user-specified memory constraints, gaining maximal computational performance.\n\nReferences\n\n[1] Tianqi Chen, Bing Xu, Zhiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. 
arXiv preprint arXiv:1604.06174, 2016.\n\n[2] Benjamin Dauvergne and Laurent Hasco\u00ebt. The data-flow equations of checkpointing in reverse automatic differentiation. In Computational Science\u2013ICCS 2006, pages 566\u2013573. Springer, 2006.\n\n[3] Douglas Eck and Juergen Schmidhuber. A first look at music composition using LSTM recurrent neural networks. Istituto Dalle Molle Di Studi Sull Intelligenza Artificiale, 103, 2002.\n\n[4] Alex Graves. Supervised Sequence Labelling with Recurrent Neural Networks. Studies in Computational Intelligence. Springer, 2012.\n\n[5] Alex Graves, Abdel-rahman Mohamed, and Geoffrey Hinton. Speech recognition with deep recurrent neural networks. In Acoustics, Speech and Signal Processing (ICASSP), 2013 IEEE International Conference on, pages 6645\u20136649. IEEE, 2013.\n\n[6] Alex Graves, Greg Wayne, Malcolm Reynolds, Tim Harley, Ivo Danihelka, Agnieszka Grabska-Barwi\u0144ska, Sergio G\u00f3mez Colmenarejo, Edward Grefenstette, Tiago Ramalho, John Agapiou, Adri\u00e0 Puigdom\u00e8nech Badia, Karl Moritz Hermann, Yori Zwols, Georg Ostrovski, Adam Cain, Helen King, Christopher Summerfield, Phil Blunsom, Koray Kavukcuoglu, and Demis Hassabis. Hybrid computing using a neural network with dynamic external memory. Nature, advance online publication, October 2016.\n\n[7] Edward Grefenstette, Karl Moritz Hermann, Mustafa Suleyman, and Phil Blunsom. Learning to transduce with unbounded memory. In Advances in Neural Information Processing Systems, pages 1819\u20131827, 2015.\n\n[8] Karol Gregor, Ivo Danihelka, Alex Graves, and Daan Wierstra. DRAW: A recurrent neural network for image generation. arXiv preprint arXiv:1502.04623, 2015.\n\n[9] Sepp Hochreiter and J\u00fcrgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735\u20131780, 1997.\n\n[10] Volodymyr Mnih, Adri\u00e0 Puigdom\u00e8nech Badia, Mehdi Mirza, Alex Graves, Timothy P. 
Lillicrap, Tim Harley, David Silver, and Koray Kavukcuoglu. Asynchronous methods for deep reinforcement learning. In Proceedings of the 33rd International Conference on Machine Learning (ICML), pages 1928\u20131937, 2016.\n\n[11] David E Rumelhart, Geoffrey E Hinton, and Ronald J Williams. Learning internal representations by error propagation. Technical report, DTIC Document, 1985.\n\n[12] Ivan Sorokin, Alexey Seleznev, Mikhail Pavlov, Aleksandr Fedorov, and Anastasiia Ignateva. Deep attention recurrent Q-network. arXiv preprint arXiv:1512.01693, 2015.\n\n[13] Ilya Sutskever, James Martens, and Geoffrey E Hinton. Generating text with recurrent neural networks. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pages 1017\u20131024, 2011.\n\n[14] Paul J Werbos. Backpropagation through time: what it does and how to do it. Proceedings of the IEEE, 78(10):1550\u20131560, 1990.\n", "award": [], "sourceid": 2047, "authors": [{"given_name": "Audrunas", "family_name": "Gruslys", "institution": "Google DeepMind"}, {"given_name": "Remi", "family_name": "Munos", "institution": "Google DeepMind"}, {"given_name": "Ivo", "family_name": "Danihelka", "institution": "Google DeepMind"}, {"given_name": "Marc", "family_name": "Lanctot", "institution": "Google DeepMind"}, {"given_name": "Alex", "family_name": "Graves", "institution": "Google DeepMind"}]}