{"title": "Communication-efficient Distributed SGD with Sketching", "book": "Advances in Neural Information Processing Systems", "page_first": 13144, "page_last": 13154, "abstract": "Large-scale distributed training of neural networks is often limited by network bandwidth, wherein the communication time overwhelms the local computation time. Motivated by the success of sketching methods in sub-linear/streaming algorithms, we introduce Sketched-SGD, an algorithm for carrying out distributed SGD by communicating sketches instead of full gradients. We show that \\ssgd has favorable convergence rates on several classes of functions. When considering all communication -- both of gradients and of updated model weights -- Sketched-SGD reduces the amount of communication required compared to other gradient compression methods from $\\mathcal{O}(d)$ or $\\mathcal{O}(W)$ to $\\mathcal{O}(\\log d)$, where $d$ is the number of model parameters and $W$ is the number of workers participating in training. We run experiments on a transformer model, an LSTM, and a residual network, demonstrating up to a 40x reduction in total communication cost with no loss in final model performance. 
We also show experimentally that Sketched-SGD scales to at least 256 workers without increasing communication cost or degrading model performance.", "full_text": "Communication-ef\ufb01cient Distributed SGD with\n\nSketching\n\nNikita Ivkin \u2217\u2020\n\nAmazon\n\nivkin@amazon.com\n\nDaniel Rothchild \u2217\n\nUC Berkeley\n\ndrothchild@berkeley.edu\n\nEnayat Ullah \u2217\n\nJohns Hopkins University\n\nenayat@jhu.edu\n\nVladimir Braverman \u2021\nJohns Hopkins University\n\nvova@cs.jhu.edu\n\nIon Stoica\nUC Berkeley\n\nistoica@berkeley.edu\n\nRaman Arora\n\nJohns Hopkins University\n\narora@cs.jhu.edu\n\nAbstract\n\nLarge-scale distributed training of neural networks is often limited by network band-\nwidth, wherein the communication time overwhelms the local computation time.\nMotivated by the success of sketching methods in sub-linear/streaming algorithms,\nwe introduce SKETCHED-SGD4, an algorithm for carrying out distributed SGD by\ncommunicating sketches instead of full gradients. We show that SKETCHED-SGD\nhas favorable convergence rates on several classes of functions. When considering\nall communication \u2013 both of gradients and of updated model weights \u2013 SKETCHED-\nSGD reduces the amount of communication required compared to other gradient\ncompression methods from O(d) or O(W ) to O(log d), where d is the number\nof model parameters and W is the number of workers participating in training.\nWe run experiments on a transformer model, an LSTM, and a residual network,\ndemonstrating up to a 40x reduction in total communication cost with no loss in\n\ufb01nal model performance. We also show experimentally that SKETCHED-SGD\nscales to at least 256 workers without increasing communication cost or degrading\nmodel performance.\n\n1\n\nIntroduction\n\nModern machine learning training workloads are commonly distributed across many machines using\ndata-parallel synchronous stochastic gradient descent. 
At each iteration, W worker nodes split a\nmini-batch of size B; each worker computes the gradient of the loss on its portion of the data, and then\na parameter server sums each worker\u2019s gradient to yield the full mini-batch gradient. After using this\ngradient to update the model parameters, the parameter server must send back the updated weights to\neach worker. We emphasize that our method can naturally be extended to other topologies as well\n(e.g. ring, complete, etc.) \u2013 in particular we would then communicate sketches over a minimum\nspanning tree of the communication graph. However, for ease of exposition, in this work we focus\nexclusively on the star topology. For a \ufb01xed batch size B, the amount of data each worker processes\n\u2013 and therefore the amount of computation required \u2013 is inversely proportional to W . On the other\nhand, the amount of communication required per worker is independent of W . Even with optimal\ninterleaving of the communication and computation, the total training time is at least the maximum\n\n\u2217equal contribution\n\u2020This work was done while the author was at Johns Hopkins University.\n\u2021This work was done, in part, while the author was visiting the Simons Institute for the Theory of Computing.\n4Code is available at https://github.com/dhroth/sketchedsgd\n\n33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.\n\n\fof the per-worker communication time and per-worker computation time. Increasing the number of\nworkers W therefore yields an increasingly marginal reduction in the training time, despite increasing\nthe overall training cost (number of machines times training time) linearly in W .\nSeveral approaches address this issue by using a large batch size to increase the per-worker computa-\ntion time [You et al., 2017, Goyal et al., 2017]. 
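As a concrete toy illustration of this setup, the following hypothetical NumPy sketch (not the paper's code; all names are illustrative) runs one synchronous data-parallel step with a least-squares loss. With equal shards, averaging the workers' gradients recovers exactly the full mini-batch gradient.

```python
import numpy as np

def worker_gradient(w, X, y):
    # Gradient of 0.5 * ||Xw - y||^2 / n on this worker's shard.
    return X.T @ (X @ w - y) / len(y)

def parameter_server_step(w, shards, lr=0.1):
    # Each worker computes a gradient on its shard; the server averages
    # them, which (for equal shards) equals the full mini-batch gradient.
    grads = [worker_gradient(w, X, y) for X, y in shards]
    g = np.mean(grads, axis=0)
    return w - lr * g

rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 4)), rng.normal(size=32)
W = 4
shards = list(zip(np.split(X, W), np.split(y, W)))

w = np.zeros(4)
w_new = parameter_server_step(w, shards)

# Averaging the per-worker gradients matches the full-batch gradient.
full_g = worker_gradient(w, X, y)
avg_g = np.mean([worker_gradient(w, Xs, ys) for Xs, ys in shards], axis=0)
assert np.allclose(full_g, avg_g)
```

The per-worker computation here shrinks as W grows, while each worker's communication (its full gradient) stays fixed, which is the bottleneck the paper targets.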
However, theoretical and empirical evidence both\nsuggest that there is a maximum mini-batch size beyond which the number of iterations required\nto converge stops decreasing, and generalization error begins to increase [Ma et al., 2017, Li et al.,\n2014, Golmant et al., 2018, Shallue et al., 2018, Keskar et al., 2016, Hoffer et al., 2017]. In this\npaper, we aim instead to decrease the communication cost per worker. We use a technique from\nstreaming algorithms called sketching, which allows us to recover favorable convergence guarantees\nof vanilla SGD. In short, our algorithm has workers send gradient sketches of size O(log d) instead\nof the gradients themselves. Although other methods for reducing the communication cost exist, to\nour knowledge ours is the only one that gives a per-worker communication cost that is sub-linear in d\nand constant in W . In practice, we show that our method achieves high compression for large d with\nno loss in model accuracy, and that it scales as expected to large W .\n\n2 Related Work\n\nMost existing methods for reducing communication cost in synchronous data-parallel distributed\nSGD either quantize or sparsify gradients. A number of quantization methods have been proposed.\nThese methods either achieve only a constant reduction in the communication cost per iteration\n[Wen et al., 2017, Bernstein et al., 2018], or achieve an asymptotic reduction in communication cost\nper iteration at the expense of an equal (or greater) asymptotic increase in the number of iterations\nrequired [Alistarh et al., 2017]. Even in the latter case, the total communication required for all of\ntraining sees no asymptotic improvement.\nOther methods sparsify the gradients instead of quantizing each gradient element [Stich et al., 2018,\nAlistarh et al., 2018, Lin et al., 2017]. A popular heuristic is to send the top-k coordinates of the\nlocal worker gradients and then average them to obtain an approximate mini-batch gradient. 
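The local top-k heuristic just described can be sketched as follows (illustrative NumPy code, not any particular paper's implementation). Note that the support of the averaged update can span up to kW coordinates, since the workers' top-k sets need not overlap.

```python
import numpy as np

def topk_mask(g, k):
    # Keep only the k largest-magnitude coordinates of g.
    out = np.zeros_like(g)
    idx = np.argsort(np.abs(g))[-k:]
    out[idx] = g[idx]
    return out

rng = np.random.default_rng(0)
d, W, k = 1_000, 8, 10
grads = [rng.normal(size=d) for _ in range(W)]

# Average the k-sparse local gradients to approximate the mini-batch gradient.
approx = np.mean([topk_mask(g, k) for g in grads], axis=0)

# The union of the local supports can touch up to k*W coordinates, which is
# why broadcasting the updated parameters back costs O(W) in this scheme.
updated = np.count_nonzero(approx)
assert k <= updated <= k * W
```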
These\nmethods can achieve good performance in practice, but they suffer from a few drawbacks. They\ncurrently have no convergence guarantees, since the estimated mini-batch gradient can be very far\nfrom the true mini-batch gradient (unless explicitly assumed, as in e.g. Alistarh et al. [2018]), which\nprecludes appealing to any known convergence result. Another drawback is that, although these\nmethods achieve high compression rates when the workers transmit gradients to the parameter server,\nthe return communication of the updated model parameters grows as O(W ): the local top-k of each\nworker may be disjoint, so there can be as many as kW parameters updated each iteration. This\nO(W ) communication cost is not just a technicality, since reducing the back-communication to O(k)\nwould require sparsifying the sum of the local top-k, which could hinder convergence. Because of\nthis scaling, local top-k methods suffer from poor compression in settings with large W .\nFrom another standpoint, all gradient compression techniques yield either biased or unbiased gradient\nestimates. A number of quantization methods are crafted speci\ufb01cally to yield unbiased estimates,\nsuch that the theoretical guarantees of SGD continue to apply [Alistarh et al., 2017, Wen et al., 2017].\nHowever, even without these guarantees, a number of methods using biased gradient estimates were\nalso found to work well in practice [Bernstein et al., 2018, Seide et al., 2014, Strom, 2015]. Recently,\nStich et al. [2018], Karimireddy et al. [2019] gave convergence guarantees for this kind of biased\ncompression algorithm, showing that accumulating compression error locally in the workers can\novercome the bias in the weight updates as long as the compression algorithm obeys certain properties.\nOur method falls into this category, and we prove that compressing gradients with sketches obeys\nthese properties and therefore enjoys the convergence guarantees in Stich et al. [2018]. 
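A minimal single-machine illustration of this kind of biased compression with error accumulation, in the spirit of the analysis of Stich et al. [2018] (a toy quadratic objective and hypothetical names, not their code):

```python
import numpy as np

def topk(g, k):
    out = np.zeros_like(g)
    idx = np.argsort(np.abs(g))[-k:]
    out[idx] = g[idx]
    return out

rng = np.random.default_rng(0)
d, k, lr = 50, 5, 0.05
target = rng.normal(size=d)

w = np.zeros(d)
err = np.zeros(d)                  # local error accumulation vector
for _ in range(2000):
    g = w - target                 # gradient of 0.5 * ||w - target||^2
    corrected = lr * g + err       # fold the previous compression error back in
    update = topk(corrected, k)    # biased, k-sparse compression
    err = corrected - update       # accumulate what was discarded
    w -= update

# Despite updating only k of d coordinates per step, the iterates converge.
assert np.linalg.norm(w - target) < 1e-3
```

Discarded coordinates are not lost: their contribution is replayed in later iterations, which is what overcomes the bias of the compressor.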
In effect, we\nintroduce a method that extends the theoretical results of Stich et al. [2018] from a single machine\nto the distributed setting. Concurrently with this work, Koloskova et al. [2019] also introduce a\ndistributed learning algorithm with favorable convergence guarantees, in which workers communicate\ncompressed gradients over an arbitrary network topology.\nPrior work has proposed applying sketching to address the communication bottleneck in distributed\nand Federated Learning [Kone\u02c7cn`y et al., 2016, Jiang et al., 2018]. However, these methods either do\nnot have provable guarantees, or they apply sketches only to portions of the data, failing to alleviate\nthe \u2126(W d) communication overhead. In particular, Kone\u02c7cn`y et al. [2016] propose \u201csketched updates\"\n\n2\n\n\fin Federated Learning for structured problems, and Jiang et al. [2018] introduce a range of hashing\nand quantization techniques to improve the constant in O (W d).\nAnother line of work that we draw from applies sketching techniques to learning tasks where the\nmodel itself cannot \ufb01t in memory [Aghazadeh et al., 2018, Tai et al., 2018]. In our setting, we can\nafford to keep a dense version of the model in memory, and we only make use of the memory-saving\nproperties of sketches to reduce communication between nodes participating in distributed learning.\n\n3 Preliminaries\nSGD. Let w \u2208 Rd be the parameters of the model to be trained and fi(w) be the loss incurred\nby w at the ith data point (xi, yi) \u223c D. The objective is to minimize the generalization error\n(xi,yi)\u223cD [fi(w)]. In large-scale machine learning, this objective is typically minimized\nf (w) =\nusing mini-batch stochastic gradient descent: given a step size \u03b7t, at each iteration, w is updated\nas wt+1 = wt \u2212 \u03b7tgt, where gt = \u2207w\ni\u2208M fi(w) is the gradient of the loss computed on\na minibatch M.\ni.e.\n\n(cid:3) = \u2207f (wt\u22121). 
As is standard, we further assume that the gt have bounded moment\n(cid:105) \u2264 \u03c32 for constants G\n\n(cid:105) \u2264 G2 and E(cid:104)(cid:107)gt \u2212 \u2207f (wt)(cid:107)2\n\nIf M is randomly selected, then the gradient estimates gt are unbiased:\n\nE(cid:2)gt|{wi}t\u22121\nand variance: E(cid:104)(cid:107)gt(cid:107)2\n\n(cid:80)\n\n2 |{wi}t\u22121\n\n2 |{wi}t\u22121\n\ni=0\n\nE\n\ni=0\n\ni=0\n\n\u00b5T\n\nand \u03c3. We adopt the usual de\ufb01nitions for smooth and strongly convex functions:\nDe\ufb01nition 1 (Smooth strongly convex function). f : Rd \u2192 R is a L-smooth and \u00b5-strongly convex\nif the following hold \u2200 w1, w2 \u2208 Rd,\n1. (cid:107)\u2207f (w2) \u2212 \u2207f (w2)(cid:107) \u2264 L(cid:107)w2 \u2212 w1(cid:107) (Smoothness)\n2. f (w2) \u2265 f (w1) + (cid:104)\u2207f (w1), w2 \u2212 w1(cid:105) + \u00b5\n\nFor smooth strongly convex functions, SGD converges at a rate of O(cid:16) G2L\n\n(cid:17)\n2 (cid:107)w2 \u2212 w1(cid:107)2 (Strong convexity)\n\n[Rakhlin et al., 2012].\nCount Sketch. Our primary interest is in \ufb01nding large coordinates (or \u201cheavy hitters\u201d) of a gradient\nvector g \u2208 Rd. Heavy hitter sketches originated in the streaming model, where the vector g is de\ufb01ned\nby a sequence of updates {(ij, wj)}n\nj=1, such that the j-th update modi\ufb01es the ij-th coordinate of g\nas gij += wj [Charikar et al., 2002, Cormode and Muthukrishnan, 2005, Braverman et al., 2017]. In\nthe streaming model, sketches must use memory sublinear in both d and n.\n\u03b5 log d) using a Count\nIn this work we compress a gradient vector g into a sketch S(g) of size O( 1\nSketch [Charikar et al., 2002]. A Count Sketch S(g) approximates every coordinate of g with an (cid:96)2\ni \u2212 \u03b5(cid:107)g(cid:107)2\n2 \u2264 \u02c6g2\n2. In\nguarantee: it is always possible to recover \u02c6gi from S(g) such that g2\naddition, S(g) can approximate the (cid:96)2 norm of the entire gradient. 
These two properties let a sketch\n\ufb01nd every (cid:96)2 heavy hitter, i.e. every coordinate i such that g2\n2. With a small enough \u03b5, the\nset of heavy hitters can be used as approximation of top-k largest coordinates of gradient vector g.\nDue to its linearity, the Count Sketch is widely adopted in distributed systems. Consider the case\nof a parameter server and two workers hosting vectors g1 and g2. To reduce communication, both\nworkers can send the parameter server sketches S(g1) and S(g2) instead of the vectors themselves.\nThe parameter server can then merge these sketches as S(g) = S(g1 + g2) = S(g1) + S(g2). This\nlets the parameter server \ufb01nd the approximate top-k largest coordinates in a vector distributed among\nmany workers. We defer a more detailed discussion of the Count Sketch to Appendix C.\n\ni > \u03b5(cid:107)g(cid:107)2\n\ni + \u03b5(cid:107)g(cid:107)2\n\ni \u2264 g2\n\n4 Sketched SGD\n\nIn SKETCHED-SGD, each worker transmits a sketch of its gradient instead of the gradient itself, as\ndescribed above. The parameter server sums the workers\u2019 sketches, and then recovers the largest\ngradient elements by magnitude from the summed sketch. To improve the compression properties of\nsketching, we then perform a second round of communication, in which the parameter server requests\nthe exact values of the top-k, and uses the sum of those in the weight update. This algorithm for\nrecovering top-k elements from a sketch is summarized in Algorithm 1.\nEvery iteration, only k values of each worker\u2019s gradient are included in the \ufb01nal weight update.\nInstead of discarding the remaining d \u2212 k gradient elements, it is important both theoretically and\n\n3\n\n\fempirically to accumulate these elements in local error accumulation vectors, which are then added to\nthe next iteration\u2019s gradient [Karimireddy et al., 2019, Stich et al., 2018]. 
This process is summarized in Algorithm 2.\n\nAlgorithm 1 HEAVYMIX\nInput: S - sketch of gradient g; k - parameter\n1: Query ℓ̂^2 = (1 ± 0.5)‖g‖_2^2 from sketch S\n2: ∀i query ĝ_i^2 = g_i^2 ± (1/2k)‖g‖_2^2 from sketch S\n3: H ← {i : ĝ_i^2 ≥ ℓ̂^2/k} and NH ← {i : ĝ_i^2 < ℓ̂^2/k}\n4: Topk = H ∪ rand_l(NH), where l = k − |H|\n5: Second round of communication to get exact values of Topk\nOutput: g̃ with g̃_i = g_i for all i ∈ Topk and g̃_i = 0 for all i ∉ Topk\n\nAlgorithm 2 SKETCHED-SGD\nInput: k, ξ, T, W\n1: η_t ← 1/(t+ξ), q_t ← (ξ+t)^2, Q_T = Σ_{t=1}^T q_t, a_0 = 0\n2: for t = 1, 2, ..., T do\n3: Compute stochastic gradient g^i_t (Worker_i)\n4: Error correction: ḡ^i_t = η_t g^i_t + a^i_{t−1} (Worker_i)\n5: Compute sketch S^i_t of ḡ^i_t and send it to the Parameter Server (Worker_i)\n6: Aggregate sketches S_t = (1/W) Σ_{i=1}^W S^i_t (Parameter Server)\n7: g̃_t = HEAVYMIX(S_t, k) (Parameter Server)\n8: Update w_{t+1} = w_t − g̃_t and send the k-sparse g̃_t to the Workers (Parameter Server)\n9: Error accumulation: a^i_t = ḡ^i_t − g̃_t (Worker_i)\n10: end for\nOutput: ŵ_T = (1/Q_T) Σ_{t=1}^T q_t w_t\n\nWe now state convergence results for SKETCHED-SGD. Proofs are deferred to Appendix A.\nTheorem 1 (strongly convex, smooth). Let f : R^d → R be an L-smooth, µ-strongly convex function, and let the data be shared among W workers. Given 0 < k ≤ d and 0 < α, δ < 1, Algorithm 2 SKETCHED-SGD run with sketch size O(k log(dT/δ)) and step size η_t = 1/(t+ξ), with ξ > 2 + d(1+β)/(k(1+ρ)), β > 4, and ρ = 4β/((β−4)(β+1)^2), after T steps outputs ŵ_T such that the following holds:\n1. With probability at least 1 − δ, E[f(ŵ_T)] − f(w*) ≤ O( σ^2/(µT) + d^2 G^2 L/(k^2 µ^2 T^2) + d^3 G^3/(k^3 µ T^3) ).\n2. The total communication per update is Θ(k log(dT/δ) W) bits.\nRemarks\n1. The convergence rate for vanilla SGD is O(1/T). Therefore, our error is larger than the SGD error when T = o((d/k)^2), and approaches the SGD error for T = Ω((d/k)^2).\n2. Although not stated in this theorem, Stich et al. [2018] show that using the top-k coordinates of the true mini-batch gradient as the SGD update step yields a convergence rate equivalent to that of SKETCHED-SGD. We therefore use this "true top-k" method as a baseline for our results.\n3. Note that the leading term in the error is O(σ^2/T) (as opposed to O(G^2/T) in [Stich et al., 2018]); this implies that in settings where the largest allowed mini-batch is too large to fit on one machine, so that going distributed lets us use larger mini-batches, the variance is reduced by a factor of W. This reduces the number of iterations required (asymptotically) linearly in W.\n4. As is standard, the above high-probability bound can be converted to a bound in expectation (over the randomness in sketching); this is stated as Theorem 6 in Appendix A.\n5. The result of [Karimireddy et al., 2019] allows us to extend our theorems to smooth nonconvex and non-smooth convex functions; these are presented as Theorems 4 and 5 in Appendix B.\n\nProof Sketch. The proof consists of two parts. First, we show that SKETCHED-SGD satisfies the criteria in Stich et al. [2018], from which we obtain a convergence result when running SKETCHED-SGD on a single machine. 
We then use properties of the Count Sketch to extend this result to the distributed setting.\n\nFor the first part, the key idea is to show that our heavy hitter recovery routine HEAVYMIX satisfies a contraction property, defined below.\nDefinition 2 (τ-contraction [Stich et al., 2018]). A τ-contraction operator is a possibly randomized operator comp : R^d → R^d that satisfies: ∀x ∈ R^d, E[‖x − comp(x)‖^2] ≤ (1 − τ)‖x‖^2.\nGiven a contraction operator with τ = k/d, and assuming that the stochastic gradients g are unbiased and bounded as E[‖g‖^2] ≤ G^2, choosing the step size appropriately, Stich et al. [2018] give a convergence rate of O( G^2/(µT) + d^2 G^2 L/(k^2 µ^2 T^2) + d^3 G^3/(k^3 µ T^3) ) for sparsified SGD with error accumulation. As stated in Lemma 1, HEAVYMIX satisfies this contraction property, and therefore inherits this (single-machine) convergence result:\nLemma 1. HEAVYMIX with sketch size Θ(k log(d/δ)) is a k/d-contraction with probability ≥ 1 − δ.\nThis completes the first part of the proof. To extend SKETCHED-SGD to the distributed setting, we exploit the fact that Count Sketches are linear and can approximate ℓ2 norms. The full proof is deferred to Appendix A.\n\n5 Empirical Results\n\n5.1 Training Algorithm\n\nIn practice, we modify SKETCHED-SGD in the following ways:\n• We employ momentum when training. Following Lin et al. [2017], we use momentum correction and momentum factor masking. 
Momentum factor masking mitigates the effects of stale momentum, and momentum correction is a way to do error feedback on SGD with momentum [Karimireddy et al., 2019].\n• We use the Count Sketch to identify heavy coordinates; however, we perform an additional round of communication to collect the exact values of those coordinates. In addition, to identify the top-k heavy coordinates, we query the Count Sketch, and then each of the workers, for the top Pk elements instead; this is a common technique used with sketching to improve stability. The total resulting communication cost is Pk + |S| + k per worker, where |S| is the size of the sketch, and the last k corresponds to the updated model parameters the parameter server must send back to the workers.\n• We transmit gradients of the bias terms uncompressed. The number of bias terms in our models is < 1% of the total number of parameters.\nOur empirical training procedure is summarized in Algorithm 3.\n\nAlgorithm 3 EMPIRICAL TRAINING\nInput: k, η_t, m, T\n1: ∀i: u^i, v^i ← 0\n2: Initialize w^i_0 from the same random seed on each Worker\n3: for t = 1, 2, ..., T do\n4: Compute stochastic gradient g^i_t (Worker_i)\n5: Momentum: u^i ← m·u^i + g^i_t (Worker_i)\n6: Error accumulation: v^i ← v^i + u^i (Worker_i)\n7: Compute sketch S^i_t of v^i and send it to the Parameter Server (Worker_i)\n8: Aggregate sketches S_t = (1/W) Σ_{i=1}^W S^i_t (Parameter Server)\n9: Recover the top-Pk coordinates from S_t: g̃_t = top_Pk(S_t) (Parameter Server)\n10: Query all workers for exact values of the nonzero elements of g̃_t; store their sum in g̃_t (Parameter Server)\n11: Send the k-sparse g̃_t to the Workers (Parameter Server)\n12: Update w^i_{t+1} = w^i_t − η_t g̃_t on each worker (Worker_i)\n13: Zero u^i and v^i at every coordinate j such that (g̃_t)_j ≠ 0 (Worker_i)\n14: end for\n\n5.2 Sketching Implementation\n\nWe implement a parallelized Count Sketch with PyTorch [Paszke et al., 2017]. The Count Sketch data structure supports a query method, which returns a provable ±ε‖g‖_2 approximation to each coordinate value. However, to the best of our knowledge, there is no efficient way to find heavy coordinates in the presence of negative inputs. Fortunately, in our application, it is computationally efficient on the GPU to simply query the sketch for every gradient coordinate, and then choose the largest elements.\n\nFigure 1: Learning curves for a transformer model trained on the WMT 2014 English to German translation task. All models included here achieve comparable BLEU scores after 60,000 iterations (see Table 1). Each run used 4 workers.\n\n5.3 Large d\n\nFirst, we show that SKETCHED-SGD achieves high compression with no loss in accuracy. Because the sketch size grows as O(log d), we expect to see the greatest compression rates for large d. Accordingly, we test on a transformer model with 90M parameters, and on a stacked LSTM model with 73M parameters. We train both models on the WMT 2014 English to German translation task, and we use code from the OpenNMT project [Klein et al., 2017]. 
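The recovery strategy described in Section 5.2, querying the sketch for every coordinate at once and keeping the largest estimates, can be sketched as follows (a hypothetical NumPy stand-in for the parallelized PyTorch implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
d, rows, cols, k = 50_000, 5, 5_000, 10

# Hash functions (bucket and sign per row), shared by sketching and querying.
buckets = rng.integers(0, cols, size=(rows, d))
signs = rng.choice([-1.0, 1.0], size=(rows, d))

g = rng.normal(size=d)
heavy = rng.choice(d, size=k, replace=False)
g[heavy] += 50.0 * rng.choice([-1.0, 1.0], size=k)  # plant signed heavy hitters

# Build the sketch table.
table = np.zeros((rows, cols))
for r in range(rows):
    np.add.at(table[r], buckets[r], signs[r] * g)

# Query every coordinate in one vectorized shot, then keep the k largest.
row_idx = np.arange(rows)[:, None]
estimates = np.median(signs * table[row_idx, buckets], axis=0)
recovered = np.argsort(np.abs(estimates))[-k:]

assert set(recovered) == set(heavy)  # the planted heavy hitters are found
```

Because the query is a gather plus a median over a handful of rows, it vectorizes cleanly, which is why querying all d coordinates is cheap on a GPU even without a dedicated heavy-hitter extraction routine.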
In all cases, the compression factor\nfor SKETCHED-SGD is computed as 2d/(|S| + P k + k), where 2d is the cost to send a (dense)\ngradient and receive a new (dense) parameter vector, |S| is the sketch size, P k is the number of\nelements sent in the second round of communication, and the last k represents the number of modi\ufb01ed\nparameter values that must be sent back to each worker.\nSKETCHED-SGD achieves the same theoretical convergence rate as top-k SGD, in which the\nweight update consists of the top-k elements of the full mini-batch gradient. We therefore perform\nexperiments with SKETCHED-SGD using a value of k that yields good performance for top-k SGD.\nFigure 2 shows top-k results over a range of values of k. Curiously, performance starts to degrade for\nlarge k. Although performance on the training data should in principle strictly improve for larger k,\nsparsifying gradients regularizes the model, so k < d may yield optimal performance on the test set.\nIn addition, we expect performance to degrade on both the training and test sets for large k due to\nmomentum factor masking. To mitigate stale momentum updates, momentum factor masking zeros\nthe velocity vector at the k coordinates that were updated in each iteration. In the limit k = d, this\ncompletely negates the momentum, hindering convergence. For all SKETCHED-SGD experiments\non these two models, we use k = 100, 000, for which top-k SGD yields a BLEU score of 26.65\nfor the transformer and 22.2 for the LSTM. For reference, uncompressed distributed SGD with the\nsame hyperparameters achieves a BLEU of 26.29 for the transformer and 20.87 for the LSTM. Using\nSKETCHED-SGD, we can obtain, with no loss in BLEU, a 40x reduction in the total communication\ncost during training, including the cost to disseminate updated model parameters. See Table 1 for a\nsummary of BLEU results. 
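The compression arithmetic can be checked directly. The sketch shapes below are those reported in the footnotes to Table 1, with d ≈ 90M for the transformer and k = 100,000; the helper name is ours:

```python
def compression(d, sketch_rows, sketch_cols, P, k):
    # 2d / (|S| + P*k + k): dense send+receive vs. sketch, exact top-P*k,
    # and the k-sparse update sent back to each worker.
    S = sketch_rows * sketch_cols
    return 2 * d / (S + P * k + k)

d, k = 90_000_000, 100_000
print(compression(d, 5, 1_000_000, 36, k))   # ~20.7x configuration
print(compression(d, 15, 180_000, 16, k))    # ~40.9x configuration
```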
Compression numbers include both the communication required to send gradients as well as the cost to send back the new model parameters. We do not include the cost to request the Pk coordinates, nor to specify which k model parameters have been updated, since these quantities can be efficiently coded and contribute little to the overall communication.\n\nFigure 2: True top-k results for a range of k. (a) WMT14 Translation Task: two models (transformer and LSTM) on the WMT 2014 English to German translation task. (b) CIFAR-10 Classification Task: a residual network. For the larger models (left), true top-k slightly outperforms the baseline for a range of k. We suspect this is because k-sparsifying gradients serves to regularize the model.\n\nMethod | BLEU (transformer) | BLEU (LSTM)\nUncompressed Distributed SGD | 26.29 | 20.87\nTop-100,000 SGD | 26.65 | 22.2\nSKETCHED-SGD, 20x compression | 26.87^5 | –\nSKETCHED-SGD, 40x compression | 26.79^6 | 20.95^7\n\nTable 1: BLEU scores on the test data achieved for uncompressed distributed SGD, top-k SGD, and SKETCHED-SGD with 20x and 40x compression. Compression rates represent the total reduction in communication, including the cost to transmit the updated model parameters. Larger BLEU score is better. For both models, top-k SGD with k = 100,000 achieves a higher BLEU score than uncompressed distributed SGD. This difference may be within the error bars, but if not, it may be that stepping only in the direction of the top-k serves as a regularizer on the optimizer. Our main experiments are on the transformer model, for which we run additional experiments using 20x compression that we did not complete for the LSTM model.\n5 Sketch size: 5 rows by 1M columns; P = 36.\n6 Sketch size: 15 rows by 180,000 columns; P = 16.\n7 Sketch size: 5 rows by 180,000 columns; P = 26.\n\nGiven that our algorithm involves a second round of communication in which Pk gradient elements are transmitted, we investigate the tradeoff between a large sketch size and a large value of P. Approaching a sketch size of zero corresponds to using a weight update that is the top-k of a randomly chosen set of Pk gradient coordinates. Experiments with extremely small sketch size |S| or extremely small values of P tended to diverge or achieve very low BLEU score. For values of |S|/Pk closer to 1, we plot learning curves in Figure 1. As expected, uncompressed SGD trains fastest, followed by top-k SGD, then 20x compression SKETCHED-SGD, then 40x compression SKETCHED-SGD. For the two 20x compression runs, the ratio of the sketch size to the number of exact gradient values computed has little effect on convergence speed. However, the higher compression runs prefer a relatively larger value of P.\n\n5.4 Large W\n\nTo re-iterate, the per-worker communication cost for SKETCHED-SGD is not only sub-linear in d, but also independent of W. To demonstrate this experimentally, we train a residual network on the CIFAR-10 dataset with SKETCHED-SGD, using up to 256 workers [Krizhevsky and Hinton, 2009]. We compare to local top-k, a method where each worker computes and transmits only the top-k elements of its gradient. The version of local top-k SGD we compare to is similar to Deep Gradient Compression, except we do not clip gradients, and we warm up the learning rate instead of the sparsity [Lin et al., 2017]. Results are shown in Figure 4. Neither algorithm sees an appreciable drop in accuracy with more workers, up to W = 256. However, while the communication cost of SKETCHED-SGD is constant in W, the communication cost for local top-k scales with W until reaching Θ(d). This scaling occurs because the local top-k of each worker might be disjoint, leading to as many as kW parameters being updated. In practice, we do in fact observe nearly linear scaling of the number of parameters updated each iteration, until saturating at d (dashed orange line in Figure 4). For W = 256, the communication of the updated model parameters back to each worker is nearly dense (d ≈ 6.5 × 10^6), reducing the overall compression of local top-k to at best ∼2x.\nFor a fixed small number of workers (W = 4), we also investigate the tradeoff between compression rate and final test accuracy. Figure 3 shows this tradeoff for two values of k and a wide range of sketch sizes and values of P.\n\nFigure 3: Tradeoff between compression and model accuracy for a residual network trained on CIFAR-10. We show results for k = 50,000 as well as k = 100,000, and color code each trained model based on the ratio of sketch size to the cost of the second round of communication. The (nearly overlapping) solid orange and dashed blue lines show the accuracy achieved by top-k SGD for the two values of k, and the black line shows the accuracy achieved by uncompressed distributed SGD. All models in this plot were trained with 4 workers.\n\nAs expected, increasing the compression rate leads to decreasing test accuracy. 
In addition, as evidenced by the color coding, using a very large sketch size compared to Pk tends to yield poor results. Although high compression rates decrease accuracy, in our experience it is possible to make up for this accuracy drop by training longer. For example, choosing one of the points in Figure 3, training with 17x compression for the usual number of iterations gives 92.5% test accuracy. Training with 50% more iterations (reducing to 11x overall compression) restores accuracy to 94%. In Figure 3, every model is trained for the same number of iterations.\n\nFigure 4: Comparison between SKETCHED-SGD and local top-k SGD on CIFAR10. Neither algorithm sees an appreciable drop in performance for up to 256 workers, but the amount of communication required for local top-k grows quickly to ≈ d = 6.5 × 10^6 as the number of workers increases. As a result, the best overall compression that local top-k can achieve for many workers is 2x.\n\n6 Discussion\n\nIn this work we introduce SKETCHED-SGD, an algorithm for reducing the communication cost in distributed SGD using sketching. We provide theoretical and experimental evidence that our method can help alleviate the difficulties of scaling SGD to many workers. While uncompressed distributed SGD requires communication of size 2d, and other gradient compression methods improve this to O(d) or O(W), SKETCHED-SGD further reduces the necessary communication to O(log d). 
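A back-of-envelope comparison of per-iteration, per-worker communication for these three regimes, using the CIFAR-10 scale from Section 5.4 (d ≈ 6.5M, k = 50,000; illustrative code, not a measurement):

```python
# d and k follow the CIFAR-10 experiments; all function names are ours.

def uncompressed(d):
    return 2 * d                       # send dense gradient, receive weights

def local_topk(d, k, W):
    # Worst case: the W workers' top-k supports are disjoint, so up to
    # min(k*W, d) updated parameters must be sent back to every worker.
    return k + min(k * W, d)

def sketched(sketch_size, P, k):
    return sketch_size + P * k + k     # sketch + exact top-P*k + k-sparse update

d, k = 6_500_000, 50_000
for W in (4, 64, 256):
    print(W, uncompressed(d), local_topk(d, k, W))
# Local top-k back-communication saturates at d for large W, capping its
# overall compression near 2x, while the sketched cost does not depend on W.
```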
Besides reducing communication, our method provably converges at the same rate as SGD, and in practice we are able to reduce the total communication needed by up to 40x without experiencing a loss in model quality.

A number of other techniques for efficient training could be combined with SKETCHED-SGD, including gradient quantization and asynchronous updates. We expect that the advantages asynchronous updates bring to regular SGD will carry over to SKETCHED-SGD. Given that elements of gradient sketches are sums of gradient elements, we expect that quantizing sketches will lead to tradeoffs similar to those of quantizing the gradients themselves. Preliminary experiments show that quantizing sketches to 16 bits when training our ResNets on CIFAR-10 leads to no drop in accuracy, but we leave a full evaluation of combining quantization, as well as asynchronous updates, with SKETCHED-SGD to future work.

Machine learning models are constantly growing in size (e.g. OpenAI's GPT-2, a transformer with 1.5 billion parameters [Radford et al., 2019]), and training is being carried out on ever larger numbers of compute nodes. As communication increasingly becomes a bottleneck for large-scale training, we argue that a method that requires only O(log d) communication has the potential to enable a wide range of machine learning workloads that are currently infeasible, from highly parallel training in the cloud to Federated Learning at the edge [McMahan et al., 2016].

7 Acknowledgements

This research was supported, in part, by NSF BIGDATA grants IIS-1546482 and IIS-1838139, NSF CAREER grant 1652257, ONR Award N00014-18-1-2364 and the Lifelong Learning Machines program from DARPA/MTO. This material is based upon work supported by the National Science Foundation Graduate Research Fellowship under Grant No.
DGE 1752814.

References

Pankaj K Agarwal, Graham Cormode, Zengfeng Huang, Jeff M Phillips, Zhewei Wei, and Ke Yi. Mergeable summaries. ACM Transactions on Database Systems (TODS), 38(4):26, 2013.

Amirali Aghazadeh, Ryan Spring, Daniel Lejeune, Gautam Dasarathy, Anshumali Shrivastava, and Richard Baraniuk. MISSION: Ultra large-scale feature selection using count-sketches. In Jennifer Dy and Andreas Krause, editors, Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 80–88, Stockholmsmässan, Stockholm, Sweden, 10–15 Jul 2018. PMLR. URL http://proceedings.mlr.press/v80/aghazadeh18a.html.

Dan Alistarh, Demjan Grubic, Jerry Li, Ryota Tomioka, and Milan Vojnovic. QSGD: Communication-efficient SGD via gradient quantization and encoding. In Advances in Neural Information Processing Systems, pages 1709–1720, 2017.

Dan Alistarh, Torsten Hoefler, Mikael Johansson, Nikola Konstantinov, Sarit Khirirat, and Cédric Renggli. The convergence of sparsified gradient methods. In Advances in Neural Information Processing Systems, pages 5977–5987, 2018.

Noga Alon, Yossi Matias, and Mario Szegedy. The space complexity of approximating the frequency moments. Journal of Computer and System Sciences, 58(1):137–147, 1999.

Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli, and Anima Anandkumar. signSGD: Compressed optimisation for non-convex problems. arXiv preprint arXiv:1802.04434, 2018.

Vladimir Braverman, Stephen R Chestnut, Nikita Ivkin, Jelani Nelson, Zhengyu Wang, and David P Woodruff. BPTree: An ℓ2 heavy hitters algorithm using constant memory.
In Proceedings of the 36th ACM SIGMOD-SIGACT-SIGAI Symposium on Principles of Database Systems, pages 361–376. ACM, 2017.

Moses Charikar, Kevin Chen, and Martin Farach-Colton. Finding frequent items in data streams. In International Colloquium on Automata, Languages, and Programming, pages 693–703. Springer, 2002.

Cody Coleman, Deepak Narayanan, Daniel Kang, Tian Zhao, Jian Zhang, Luigi Nardi, Peter Bailis, Kunle Olukotun, Chris Ré, and Matei Zaharia. DAWNBench: An end-to-end deep learning benchmark and competition. Training, 100(101):102, 2017.

Graham Cormode and Shan Muthukrishnan. An improved data stream summary: the count-min sketch and its applications. Journal of Algorithms, 55(1):58–75, 2005.

Noah Golmant, Nikita Vemuri, Zhewei Yao, Vladimir Feinberg, Amir Gholami, Kai Rothauge, Michael W Mahoney, and Joseph Gonzalez. On the computational inefficiency of large batch sizes for stochastic gradient descent. arXiv preprint arXiv:1811.12941, 2018.

Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch SGD: Training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.

Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. In Advances in Neural Information Processing Systems, pages 1731–1741, 2017.

Nikita Ivkin, Zaoxing Liu, Lin F Yang, Srinivas Suresh Kumar, Gerard Lemson, Mark Neyrinck, Alexander S Szalay, Vladimir Braverman, and Tamas Budavari. Scalable streaming tools for analyzing n-body simulations: Finding halos and investigating excursion sets in one pass. Astronomy and Computing, 23:166–179, 2018.

Jiawei Jiang, Fangcheng Fu, Tong Yang, and Bin Cui. SketchML: Accelerating distributed machine learning with data sketches.
In Proceedings of the 2018 International Conference on Management of Data, pages 1269–1284. ACM, 2018.

Sai Praneeth Karimireddy, Quentin Rebjock, Sebastian U Stich, and Martin Jaggi. Error feedback fixes signSGD and other gradient compression schemes. arXiv preprint arXiv:1901.09847, 2019.

Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Guillaume Klein, Yoon Kim, Yuntian Deng, Jean Senellart, and Alexander Rush. OpenNMT: Open-source toolkit for neural machine translation. In Proceedings of ACL 2017, System Demonstrations, pages 67–72, Vancouver, Canada, July 2017. Association for Computational Linguistics. URL https://www.aclweb.org/anthology/P17-4012.

Anastasia Koloskova, Sebastian U Stich, and Martin Jaggi. Decentralized stochastic optimization and gossip algorithms with compressed communication. arXiv preprint arXiv:1902.00340, 2019.

Jakub Konečný, H Brendan McMahan, Felix X Yu, Peter Richtárik, Ananda Theertha Suresh, and Dave Bacon. Federated learning: Strategies for improving communication efficiency. arXiv preprint arXiv:1610.05492, 2016.

Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.

Mu Li, Tong Zhang, Yuqiang Chen, and Alexander J Smola. Efficient mini-batch training for stochastic optimization. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 661–670. ACM, 2014.

Yujun Lin, Song Han, Huizi Mao, Yu Wang, and William J Dally. Deep gradient compression: Reducing the communication bandwidth for distributed training. arXiv preprint arXiv:1712.01887, 2017.

Siyuan Ma, Raef Bassily, and Mikhail Belkin.
The power of interpolation: Understanding the effectiveness of SGD in modern over-parametrized learning. arXiv preprint arXiv:1712.06559, 2017.

Horia Mania, Xinghao Pan, Dimitris Papailiopoulos, Benjamin Recht, Kannan Ramchandran, and Michael I Jordan. Perturbed iterate analysis for asynchronous stochastic optimization. arXiv preprint arXiv:1507.06970, 2015.

H. Brendan McMahan, Eider Moore, Daniel Ramage, and Blaise Agüera y Arcas. Federated learning of deep networks using model averaging. CoRR, abs/1602.05629, 2016. URL http://arxiv.org/abs/1602.05629.

Shanmugavelayutham Muthukrishnan et al. Data streams: Algorithms and applications. Foundations and Trends® in Theoretical Computer Science, 1(2):117–236, 2005.

Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch, 2017.

Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. OpenAI Blog, 1:8, 2019.

Alexander Rakhlin, Ohad Shamir, Karthik Sridharan, et al. Making gradient descent optimal for strongly convex stochastic optimization. In ICML, volume 12, pages 1571–1578. Citeseer, 2012.

Frank Seide, Hao Fu, Jasha Droppo, Gang Li, and Dong Yu. 1-bit stochastic gradient descent and its application to data-parallel distributed training of speech DNNs. In Fifteenth Annual Conference of the International Speech Communication Association, 2014.

Christopher J Shallue, Jaehoon Lee, Joe Antognini, Jascha Sohl-Dickstein, Roy Frostig, and George E Dahl. Measuring the effects of data parallelism on neural network training. arXiv preprint arXiv:1811.03600, 2018.

Sebastian U Stich, Jean-Baptiste Cordonnier, and Martin Jaggi. Sparsified SGD with memory. In Advances in Neural Information Processing Systems, pages 4452–4463, 2018.

Nikko Strom.
Scalable distributed DNN training using commodity GPU cloud computing. In Sixteenth Annual Conference of the International Speech Communication Association, 2015.

Kai Sheng Tai, Vatsal Sharan, Peter Bailis, and Gregory Valiant. Sketching linear classifiers over data streams. In Proceedings of the 2018 International Conference on Management of Data, pages 757–772. ACM, 2018.

Wei Wen, Cong Xu, Feng Yan, Chunpeng Wu, Yandan Wang, Yiran Chen, and Hai Li. TernGrad: Ternary gradients to reduce communication in distributed deep learning. In Advances in Neural Information Processing Systems, pages 1509–1519, 2017.

Yang You, Igor Gitman, and Boris Ginsburg. Large batch training of convolutional networks. arXiv preprint arXiv:1708.03888, 2017.