{"title": "Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-Update SGD", "book": "Proceedings of Machine Learning and Systems", "page_first": 212, "page_last": 229, "abstract": "Large-scale machine learning training, in particular distributed stochastic gradient descent, needs to be robust to inherent system variability such as node straggling and random communication delays. This work considers a distributed training framework where each worker node is allowed to perform local model updates and the resulting models are averaged periodically. We analyze the true speed of error convergence with respect to wall-clock time (instead of the number of iterations), and analyze how it is affected by the frequency of averaging. The main contribution is the design of AdaComm, an adaptive communication strategy that starts with infrequent averaging to save communication delay and improve convergence speed, and then increases the communication frequency in order to achieve a low error floor. Rigorous experiments on training deep neural networks show that AdaComm can take 3 times less time than fully synchronous SGD, and still reach the same final training loss.", "full_text": " ADAPTIVE COMMUNICATION STRATEGIES TO ACHIEVE THE BEST\r\n ERROR-RUNTIMETRADE-OFFINLOCAL-UPDATESGD\r\n Jianyu Wang1 GauriJoshi1\r\n ABSTRACT\r\n Large-scale machine learning training, in particular, distributed stochastic gradient descent, needs to be robust\r\n to inherent system variability such as node straggling and random communication delays. This work considers\r\n a distributed training framework where each worker node is allowed to perform local model updates and the\r\n resulting models are averaged periodically. 
We analyze the true speed of error convergence with respect to wall-clock time (instead of the number of iterations), and analyze how it is affected by the frequency of averaging. The main contribution is the design of ADACOMM, an adaptive communication strategy that starts with infrequent averaging to save communication delay and improve convergence speed, and then increases the communication frequency in order to achieve a low error floor. Rigorous experiments on training deep neural networks show that ADACOMM can take 3x less time than fully synchronous SGD and still reach the same final training loss.

1 INTRODUCTION

Stochastic gradient descent (SGD) is the backbone of state-of-the-art supervised learning, which is revolutionizing inference and decision-making in many diverse applications. Classical SGD was designed to be run on a single computing node, and its error-convergence with respect to the number of iterations has been extensively analyzed and improved via accelerated SGD methods. Due to the massive training datasets and neural network architectures used today, it has become imperative to design distributed SGD implementations, where gradient computation and aggregation are parallelized across multiple worker nodes. Although parallelism boosts the amount of data processed per iteration, it exposes SGD to unpredictable node slowdowns and communication delays stemming from variability in the computing infrastructure. Thus, there is a critical need to make distributed SGD fast, yet robust to system variability.

Need to Optimize Convergence in terms of Error versus Wall-clock Time. The convergence speed of distributed SGD is a product of two factors: 1) the error in the trained model versus the number of iterations, and 2) the number of iterations completed per second. Traditional single-node SGD analysis focuses on optimizing the first factor, because the second factor is generally a constant when SGD is run on a single dedicated server. In distributed SGD, which is often run on shared cloud infrastructure, the second factor depends on several aspects such as the number of worker nodes, their local computation and communication delays, and the protocol (synchronous, asynchronous or periodic) used to aggregate their gradients. Hence, in order to achieve the fastest convergence speed we need: 1) optimization techniques (e.g., variable learning rate) to maximize the error-convergence rate with respect to iterations, and 2) scheduling techniques (e.g., straggler mitigation, infrequent communication) to maximize the number of iterations completed per second.
These directions are inter-dependent and need to be explored together rather than in isolation. While many works have advanced the first direction, the second is less explored from a theoretical point of view, and the juxtaposition of both is an unexplored problem.

Local-Update SGD to Reduce Communication Delays. A popular distributed SGD implementation is the parameter server framework (Dean et al., 2012; Cui et al., 2014; Li et al., 2014; Gupta et al., 2016; Mitliagkas et al., 2016), where in each iteration worker nodes compute gradients on one mini-batch of data and a central parameter server aggregates these gradients (synchronously or asynchronously) and updates the parameter vector x. The constant communication between the parameter server and the worker nodes in each iteration can be expensive and slow in bandwidth-limited compute environments. Recently proposed distributed training frameworks such as Elastic Averaging (Zhang et al., 2015; Chaudhari et al., 2017), Federated Learning (McMahan et al., 2016; Smith et al., 2017b) and decentralized SGD (Lian et al., 2017; Jiang et al., 2017) save this communication cost by allowing worker nodes to perform local updates to the parameter x instead of just computing gradients. The resulting locally trained models (which are different due to variability in the training data across nodes) are periodically averaged through a central server, or via direct inter-worker communication. This local-update strategy has been shown to offer significant speedup in deep neural network training (Lian et al., 2017; McMahan et al., 2016).

Error-Runtime Trade-offs in Local-Update SGD. While local updates reduce the communication delay incurred per iteration, discrepancies between the local models can result in inferior error-convergence. For example, consider the case of periodic-averaging SGD (PASGD), where each of m worker nodes makes τ local updates, and the resulting models are averaged after every τ iterations (Moritz et al., 2015; Su & Chen, 2015; Chen & Huo, 2016; Seide & Agarwal, 2016; Zhang et al., 2016; Zhou & Cong, 2017; Lin et al., 2018). A larger value of τ leads to slower convergence with respect to the number of iterations, as illustrated in Figure 1. However, if we look at the true convergence with respect to wall-clock time, then a larger τ, that is, less frequent averaging, saves communication delay and reduces the runtime per iteration. While some recent theoretical works (Zhou & Cong, 2017; Yu et al., 2018; Wang & Joshi, 2018; Stich, 2018) study this dependence of the error-convergence on the number of iterations as τ varies, achieving a provably-optimal speed-up in the true convergence with respect to wall-clock time is an open problem that we aim to address in this work.

Need for Adaptive Communication Strategies. In the error-runtime trade-off in Figure 1, we observe a trade-off between the convergence speed and the error floor when the number of local updates τ is varied. A larger τ gives a faster initial drop in the training loss but results in a higher error floor. This calls for adaptive communication strategies that start with a larger τ and gradually decrease it as the model gets closer to convergence. Such an adaptive strategy will offer a win-win in the error-runtime trade-off by achieving fast convergence as well as a low error floor. To the best of our knowledge, this is the first work to propose an adaptive communication frequency strategy.

Figure 1. This work departs from the traditional view of considering error-convergence with respect to the number of iterations, and instead considers the true convergence in terms of error versus wall-clock time. Adaptive strategies that start with infrequent model-averaging and increase the communication frequency can achieve the best error-runtime trade-off.

Main Contributions. This paper focuses on periodic-averaging local-update SGD (PASGD) and makes the following main contributions:

1. We first analyze the runtime per iteration of periodic-averaging SGD (PASGD) by modeling the local computing times and communication delays as random variables, and quantify its runtime speed-up over fully synchronous SGD. A novel insight from this analysis is that the periodic-averaging strategy not only reduces the communication delay but also mitigates synchronization delays in waiting for slow or straggling nodes.

2. By combining the runtime analysis with the error-convergence analysis of PASGD (Wang & Joshi, 2018), we obtain the error-runtime trade-off for different values of τ. Using this combined error-runtime trade-off, we derive an expression for the optimal communication period, which can serve as a useful guideline in practice.

3. Based on the observations from the runtime and convergence analyses, we develop an adaptive communication scheme: ADACOMM. Experiments on training the VGG-16 and ResNet-50 deep neural networks in different settings (with/without momentum, fixed/decaying learning rate) show that ADACOMM can give a 3x runtime speed-up and still reach the same low training loss as fully synchronous SGD.

4. We present a convergence analysis for PASGD with variable communication period τ and variable learning rate η, generalizing previous work (Wang & Joshi, 2018). This analysis shows that decaying τ provides similar convergence benefits as decaying the learning rate, the difference being that varying τ improves the true convergence with respect to wall-clock time. Adaptive communication can also be used in conjunction with existing learning rate schedules.

Although we focus on periodic simple-averaging of local models, the insights on error-runtime trade-offs and adaptive communication strategies are directly extendable to other communication-efficient SGD algorithms, including Federated Learning (McMahan et al., 2016), Elastic Averaging (Zhang et al., 2015) and decentralized averaging (Jiang et al., 2017; Lian et al., 2017), as well as synchronous/asynchronous distributed SGD with a central parameter server (Dean et al., 2012; Cui et al., 2014; Dutta et al., 2018).

1 Department of Electrical & Computer Engineering, Carnegie Mellon University, Pittsburgh, PA, USA. Correspondence to: Jianyu Wang, Gauri Joshi. Proceedings of the 2nd SysML Conference, Palo Alto, CA, USA, 2019. Copyright 2019 by the author(s).

2 PROBLEM FRAMEWORK

Empirical Risk Minimization via Mini-batch SGD. Our objective is to minimize an objective function F(x), the empirical risk function, with respect to model parameters denoted by x ∈ R^d. The training dataset is denoted by S = {s_1, ..., s_N}, where s_i represents the i-th labeled data point. The objective function can be expressed as the empirical risk calculated using the training data and is given by

min_{x ∈ R^d} F(x) := (1/N) Σ_{i=1}^{N} f(x; s_i)    (1)

where f(x; s_i) is the composite loss function at the i-th data point. In classic mini-batch stochastic gradient descent (SGD) (Dekel et al., 2012), updates to the parameter vector x are performed as follows. If ξ_k ⊂ S represents a randomly sampled mini-batch, then the update rule is

x_{k+1} = x_k − η g(x_k; ξ_k)    (2)

where η denotes the learning rate and the stochastic gradient is defined as g(x; ξ) = (1/|ξ|) Σ_{s_i ∈ ξ} ∇f(x; s_i). For simplicity, we will use g(x_k) instead of g(x_k; ξ_k) in the rest of the paper.

Figure 2. Illustration of PASGD in the model parameter space for m = 2 workers. The discrepancy between the local models increases with the number of local updates, τ = 3.

Figure 3. Illustration of PASGD in the time space for m = 2 and τ = 3. The lengths of the colored arrows at the i-th worker are the local-update times Y_{i,k}, which are i.i.d. across workers and updates. The blue block represents the communication delay for each model-averaging step.
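As a concrete illustration of the update rule (2), the following minimal sketch runs mini-batch SGD on a toy one-dimensional empirical risk; the quadratic per-sample loss f(x; s) = (x − s)², the learning rate, and the data values are illustrative assumptions, not from the paper:

```python
import random

def grad(x, s):
    # Gradient of the illustrative per-sample loss f(x; s) = (x - s)^2.
    return 2.0 * (x - s)

def sgd(data, lr=0.02, batch=2, steps=500, seed=0):
    rng = random.Random(seed)
    x = 0.0
    for _ in range(steps):
        xi = rng.sample(data, batch)                  # randomly sampled mini-batch xi_k
        g = sum(grad(x, s) for s in xi) / len(xi)     # stochastic gradient g(x_k; xi_k)
        x = x - lr * g                                # update rule (2)
    return x

# The iterates hover around the empirical risk minimizer, mean(data) = 1.5.
print(sgd([0.0, 1.0, 2.0, 3.0]))
```

Averaging the per-sample gradients over the sampled mini-batch is exactly the stochastic gradient g(x; ξ) = (1/|ξ|) Σ_{s_i ∈ ξ} ∇f(x; s_i) defined above.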
A complete review of the convergence properties of serial SGD can be found in (Bottou et al., 2018).

Periodic-Averaging SGD (PASGD). We consider a distributed SGD framework with m worker nodes, where all workers can communicate with others via a central server or via direct inter-worker communication. In periodic-averaging SGD, all workers start at the same initial point x_1. Each worker performs τ local mini-batch SGD updates according to (2), and the local models are averaged by a fusion node or by performing an all-node broadcast. The workers then update their local models with the averaged model, as illustrated in Figure 2. Thus, the overall update rule at the i-th worker is given by

x_{k+1}^{(i)} = (1/m) Σ_{j=1}^{m} [x_k^{(j)} − η g(x_k^{(j)})],   if k mod τ = 0;
x_{k+1}^{(i)} = x_k^{(i)} − η g(x_k^{(i)}),   otherwise,    (3)

where x_k^{(i)} denotes the model parameters at the i-th worker after k iterations and τ is defined as the communication period. Note that the iteration index k corresponds to local iterations, and not to the number of averaging steps.

Special Case (τ = 1): Fully Synchronous SGD. When τ = 1, that is, when the local models are synchronized after every iteration, periodic-averaging SGD is equivalent to fully synchronous SGD, which has the update rule

x_{k+1} = x_k − η (1/m) Σ_{i=1}^{m} g(x_k; ξ_k^{(i)}).    (4)

The analysis of fully synchronous SGD is identical to serial SGD with an m-fold larger mini-batch size.

Local Computation Times and Communication Delay. In order to analyze the effect of τ on the expected runtime per iteration, we consider the following delay model. The time taken by the i-th worker to compute a mini-batch gradient at the k-th local step is modeled as a random variable Y_{i,k} ~ F_Y, assumed to be i.i.d. across workers and mini-batches. The communication delay is a random variable D for each all-node broadcast, as illustrated in Figure 3. The value of the random variable D can depend on the number of workers as follows:

D = D_0 · s(m)    (5)

where D_0 represents the time taken for each inter-node communication, and s(m) describes how the delay scales with the number of workers, which depends on the implementation and system characteristics. For example, in the parameter server framework, the communication delay can be proportional to 2 log_2(m) by exploiting a reduction-tree structure (Iandola et al., 2016). We assume that s(m) is known beforehand for the communication-efficient distributed SGD framework under consideration.

Convergence Criteria. In the error-convergence analysis, since the objective function is non-convex, we use the expected gradient norm as an indicator of convergence, following (Ghadimi & Lan, 2013; Bottou et al., 2018). We say the algorithm achieves an ε-suboptimal solution if

E[ min_{k ∈ [1,K]} ||∇F(x_k)||² ] ≤ ε.    (6)

When ε is arbitrarily small, this condition guarantees that the algorithm converges to a stationary point.

3 JOINTLY ANALYZING RUNTIME AND ERROR-CONVERGENCE

3.1 Runtime Analysis

We now present a comparison of the runtime per iteration of periodic-averaging SGD with fully synchronous SGD to illustrate how increasing τ can lead to a large runtime speed-up. Another interesting effect of performing more local updates τ is that it mitigates the slowdown due to straggling worker nodes.

Runtime Per Iteration of Fully Synchronous SGD. Fully synchronous SGD is equivalent to periodic-averaging SGD with τ = 1. Each of the m workers computes the gradient of one mini-batch and updates the parameter vector x, which takes time Y_{i,1} at the i-th worker [1]. After all workers finish their local updates, an all-node broadcast is performed to synchronize and average the models. Thus, the total time to complete each iteration is given by

T_sync = max(Y_{1,1}, Y_{2,1}, ..., Y_{m,1}) + D    (7)
E[T_sync] = E[Y_{m:m}] + E[D]    (8)

where the Y_{i,1} are i.i.d. random variables with probability distribution F_Y and D is the communication delay. The term Y_{m:m} denotes the highest order statistic of m i.i.d. random variables (David & Nagaraja, 2003).

[1] Instead of local updates, typical implementations of fully synchronous SGD have a central server that performs the update. Here we compare PASGD with fully synchronous SGD without a central parameter server.

Runtime Per Iteration of Periodic-Averaging SGD (PASGD). In periodic-averaging SGD, each worker performs τ local updates before communicating with other workers. Let us denote the average local computation time at the i-th worker by

Ȳ_i = (Y_{i,1} + Y_{i,2} + ... + Y_{i,τ}) / τ.    (9)

Since the communication delay D is amortized over τ iterations, the average runtime per iteration is

T_PAvg = max(Ȳ_1, Ȳ_2, ..., Ȳ_m) + D/τ    (10)
E[T_PAvg] = E[Ȳ_{m:m}] + E[D]/τ.    (11)

The value of the first term Ȳ_{m:m}, and how it compares with Y_{m:m}, depends on the probability distribution F_Y of Y.

3.2 Runtime Benefits of Periodic Averaging Strategy

Speed-up over fully synchronous SGD. We evaluate the speed-up of periodic-averaging SGD over fully synchronous SGD for different Y and D to demonstrate how the relative value of computation versus communication delays affects the speed-up. Consider the simplest case where Y and D are constants, and define α = D/Y, the communication/computation ratio. Besides systems aspects such as network bandwidth and computing capacity, for deep neural network training this ratio α also depends on the size of the neural network model and the mini-batch size. See Figure 8 for a comparison of the communication/computation delays of common deep neural network architectures. In this case Ȳ_i, Ȳ_{m:m} and Y_{m:m} are all equal to Y, and the ratio of E[T_sync] and E[T_PAvg] is given by

E[T_sync] / E[T_PAvg] = (Y + D) / (Y + D/τ) = (1 + α) / (1 + α/τ).    (12)

Figure 4 shows the speed-up for different values of α and τ. When D is comparable with Y (α = 0.9), periodic-averaging SGD (PASGD) can be almost twice as fast as fully synchronous SGD.

Figure 4. The speed-up offered by using periodic-averaging SGD increases with τ (the communication period) and with the communication/computation delay ratio α = D/Y, where D is the all-node broadcast delay and Y is the time taken for each local update at a worker.

Straggler Mitigation due to Local Updates. Suppose that Y is exponentially distributed with mean y and variance y². For fully synchronous SGD, the term E[Y_{m:m}] in (8) is equal to y Σ_{i=1}^{m} 1/i, which is approximately equal to y log m. Thus, the expected runtime per iteration of fully synchronous SGD (8) increases logarithmically with the number of workers m. Let us compare this with the scaling of the runtime of periodic-averaging SGD (11). Here, Ȳ in (9) is an Erlang random variable with mean y and variance y²/τ. Since this variance is τ times smaller than that of Y, the maximum order statistic E[Ȳ_{m:m}] is smaller than E[Y_{m:m}]. Figure 5 shows the probability distributions of T_sync and T_PAvg for exponentially distributed Y. Observe that T_PAvg has a much lighter tail. This is because the effect of the variability in Y on T_PAvg is reduced, due to the Y_{m:m} in (8) being replaced by Ȳ_{m:m} (which has lower variance) in (11).

Figure 5. Probability distribution of the runtime per iteration, where the communication delay D = 1, the mean computation time y = 1, and the number of workers m = 16. Dashed lines represent the mean values.

3.3 Joint Analysis with Error-convergence

In this subsection, we combine the runtime analysis with the previous error-convergence analysis for PASGD (Wang & Joshi, 2018). Figure 6 plots the resulting theoretical error bounds versus runtime for both fully synchronous SGD (τ = 1) and PASGD. It is shown that although PASGD with τ = 10 starts with a rapid drop, it will eventually converge to a higher error floor. This theoretical result is also corroborated by the experiments in Section 5. Another direct outcome of Theorem 1 is the determination of the best communication period that balances the first and last terms in (13); we discuss the selection of the communication period in Section 4.1.

Figure 6. Illustration of the theoretical error bound versus runtime in Theorem 1. The runtime per iteration is generated under the same parameters as Figure 5. Other constants in (13) are set as follows: F(x_1) = 1, F_inf = 0, η = 0.08, L = 1, σ² = 1.
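The speed-up ratio (12) and the straggler-mitigation effect of Section 3.2 can be checked numerically. The sketch below uses the constant-delay model behind (12), plus a small Monte Carlo simulation with exponential local-update times; the trial count and seed are illustrative assumptions:

```python
import random

def speedup(alpha, tau):
    # Eq. (12): E[T_sync] / E[T_PAvg] = (1 + alpha) / (1 + alpha / tau),
    # for constant computation time Y and communication delay D = alpha * Y.
    return (1.0 + alpha) / (1.0 + alpha / tau)

def mean_max_runtime(m, tau, trials=5000, y=1.0):
    # Monte Carlo estimate of E[max over m workers of the average of tau
    # exponential local-step times] -- the first term of (11), which reduces
    # to the first term of (8) when tau = 1.
    random.seed(0)
    total = 0.0
    for _ in range(trials):
        total += max(
            sum(random.expovariate(1.0 / y) for _ in range(tau)) / tau
            for _ in range(m)
        )
    return total / trials

# With alpha = 0.9, a communication period of tau = 10 already recovers most
# of the ideal (1 + alpha) speed-up over fully synchronous SGD.
print(speedup(0.9, 10))

# Averaging tau local steps shrinks the variance seen by the max, so the
# straggler penalty E[Ybar_{m:m}] is smaller than E[Y_{m:m}].
print(mean_max_runtime(16, 10), mean_max_runtime(16, 1))
```

The second print illustrates the lighter tail of T_PAvg seen in Figure 5: the sample maximum of the averaged (Erlang) local times is well below the sample maximum of the raw exponential times.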
Due to space limitations, we state the necessary theoretical assumptions in the Appendix; the assumptions are similar to previous works (Zhou & Cong, 2017; Wang & Joshi, 2018) on the convergence of local-update SGD algorithms.

Theorem 1 (Error-runtime Convergence of PASGD). For PASGD, under certain assumptions (stated in the Appendix), if the learning rate satisfies ηL + η²L²τ(τ−1) ≤ 1, Y and D are constants, and all workers are initialized at the same point x_1, then after a total of T wall-clock time, the minimal expected squared gradient norm within the time interval T is bounded by

(2[F(x_1) − F_inf] / (ηT)) (Y + D/τ) + ηLσ²/m + η²L²σ²(τ − 1)    (13)

where L is the Lipschitz constant of the objective function and σ² is the variance bound of the mini-batch stochastic gradients.

The proof of Theorem 1 is presented in the Appendix. From the optimization error upper bound (13), one can easily observe the error-runtime trade-off for different communication periods. While a larger τ reduces the runtime per iteration and makes the first term in (13) smaller, it also adds additional noise and increases the last term.

4 ADACOMM: PROPOSED ADAPTIVE COMMUNICATION STRATEGY

Inspired by the clear trade-off between the learning curves in Figure 6, it would be better to have an adaptive communication strategy that starts with infrequent communication to improve convergence speed, and then increases the frequency to achieve a low error floor. In this section, we develop the proposed adaptive communication scheme.

The basic idea is to choose, at each point in wall-clock time, the communication period that minimizes the optimization error. One way to achieve this is to switch between the learning curves at their intersections. However, without prior knowledge of the various curves, it would be difficult to determine the switch points.

Instead, we divide the whole training procedure into uniform wall-clock time intervals of the same length T_0. At the beginning of each time interval, we select the value of τ that gives the fastest decay rate in the next T_0 wall-clock time. If the interval length T_0 is small enough and the best choice of communication period for each interval can be precisely estimated, then this adaptive scheme should achieve a win-win in the error-runtime trade-off, as illustrated in Figure 7.

Figure 7. Illustration of communication period adaptation strategies: (a) switch between curves; (b) choose the best τ for each time interval. The dashed line denotes the learning curve using adaptive communication.

After setting the interval length, the next question is how to estimate the best communication period for each time interval. In Section 4.1 we use the error-runtime analysis of Section 3.3 to find the best τ at each time.

4.1 Determining the Best Communication Period for Each Time Interval

From Theorem 1, it can be observed that there is an optimal value τ* that minimizes the optimization error bound at a given wall-clock time. In particular, consider the simplest setting where Y and D are constants. Then, by minimizing the upper bound (13) over τ, we obtain the following.

Theorem 2. For PASGD, under the same assumptions as Theorem 1, the optimization error upper bound (13) at time T is minimized when the communication period is

τ* = sqrt( 2(F(x_1) − F_inf) D / (η³ L² σ² T) ).    (14)

The proof is straightforward by setting the derivative of (13) with respect to τ to zero; we present the details in the Appendix. Suppose all workers start from the same initial point x_1 = x_{t=0}, where the subscript t denotes wall-clock time. Directly applying Theorem 2 to the first time interval, the best choice of communication period is

τ_0 = sqrt( 2(F(x_{t=0}) − F_inf) D / (η³ L² σ² T_0) ).    (15)

Similarly, for the l-th time interval, the workers can be viewed as restarting training at a new initial point x_{t=lT_0}. Applying Theorem 2 again, we have

τ_l = sqrt( 2(F(x_{t=lT_0}) − F_inf) D / (η³ L² σ² T_0) ).    (16)

Comparing (15) and (16), it is easy to see that the generated communication period sequence decreases along with the objective value F(x_t) when the learning rate is fixed. This result is consistent with the intuition that the trade-off between error-convergence and communication-efficiency varies over time. Compared to the initial phase of training, the benefit of using a large communication period diminishes as the model gets close to convergence. At this later stage, a lower error floor is preferable to speeding up the runtime.

Remark 1 (Connection to Decaying Learning Rate). Using a fixed learning rate in SGD leads to an error floor at convergence. To further reduce the error, practical SGD implementations generally decay the learning rate or increase the mini-batch size (Smith et al., 2017a; Goyal et al., 2017). As we saw from the convergence analysis in Theorem 1, performing local updates adds additional noise to the stochastic gradients, resulting in a higher error floor at convergence. Decaying the communication period can gradually reduce the variance of the gradients and yield a similar improvement in convergence. Thus, adaptive communication strategies are similar in spirit to decaying the learning rate or increasing the mini-batch size. The key difference is that here we are optimizing the true error convergence with respect to wall-clock time rather than the number of iterations.

4.2 Practical Considerations

Although (15) and (16) provide useful insights about how to adapt τ over time, it is still difficult to use them directly in practice because the Lipschitz constant L and the gradient variance bound σ² are unknown. For deep neural networks, estimating these constants can be difficult and unreliable due to the highly non-convex and high-dimensional loss surface. As an alternative, we propose a simpler rule, where we approximate F_inf by 0 and divide (16) by (15) to obtain the basic communication period update rule:

Basic update rule:   τ_l = ⌈ sqrt( F(x_{t=lT_0}) / F(x_{t=0}) ) τ_0 ⌉    (17)

where ⌈a⌉ is the ceiling function, rounding a to the nearest integer ≥ a. Since the objective function values (i.e., the training loss) F(x_{t=lT_0}) and F(x_{t=0}) can easily be obtained during training, the only remaining quantity to determine is the initial communication period τ_0. We obtain a heuristic estimate of τ_0 by a simple grid search over different values of τ, run for one or two epochs each.

4.3 Refinements to the Proposed Adaptive Strategy

4.3.1 Faster Decay When Training Saturates

The communication period update rule (17) tends to give a decreasing sequence {τ_l}. Nonetheless, it is possible that the best value of τ_l for the next time interval is larger than the current one due to random noise in the training process.
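The basic update rule (17) is easy to state in code. Here is a minimal sketch, where the loss values and the initial period τ_0 = 20 are illustrative assumptions:

```python
import math

def next_tau(loss_now, loss_init, tau0):
    # Basic update rule (17): tau_l = ceil( sqrt(F(x_{t=l*T0}) / F(x_{t=0})) * tau0 ).
    return math.ceil(math.sqrt(loss_now / loss_init) * tau0)

# As the training loss decreases, the communication period shrinks toward 1,
# i.e., toward fully synchronous SGD.
losses = [4.0, 2.0, 1.0, 0.25, 0.01]
print([next_tau(f, losses[0], tau0=20) for f in losses])   # -> [20, 15, 10, 5, 1]
```

Only the running training loss and its initial value are needed, which is exactly why the rule sidesteps the unknown constants L and σ².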
Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-Update SGD

Besides, when the training loss gets stuck on plateaus and decreases very slowly, (17) will result in τ_l saturating at the same value for a long time. To address this issue, we borrow an idea used in classic SGD, where the learning rate is decayed by a factor γ when the training loss saturates for several epochs (Goyal et al., 2017). Similarly, in our scheme, the communication period is multiplied by γ < 1 when the τ_l given by (17) is not strictly less than τ_{l-1}. To be specific, the communication period for the l-th time interval is determined as follows:

$$\tau_l = \begin{cases} \left\lceil \sqrt{\dfrac{F(x_{t=lT_0})}{F(x_{t=0})}}\,\tau_0 \right\rceil, & \text{if } \left\lceil \sqrt{\dfrac{F(x_{t=lT_0})}{F(x_{t=0})}}\,\tau_0 \right\rceil < \tau_{l-1}, \\[1ex] \gamma\,\tau_{l-1}, & \text{otherwise.} \end{cases} \quad (18)$$

In the experiments, γ = 1/2 turns out to be a good choice. One can obtain a more aggressive decay in τ_l by either reducing the value of γ or introducing a slack variable s in the condition, such as $\lceil \sqrt{F(x_{t=lT_0})/F(x_{t=0})}\,\tau_0 \rceil + s < \tau_{l-1}$.

4.3.2 Incorporating Adaptive Learning Rate

So far we have considered a fixed learning rate η for the local SGD updates at the workers. We now present an adaptive communication strategy that adjusts τ_l for a given variable learning rate schedule, in order to obtain the best error-runtime trade-off. Suppose η_l denotes the learning rate for the l-th time interval. Then, combining (15) and (16) again, we have

$$\tau_l = \left\lceil \sqrt{\frac{\eta_0^3\, F(x_{t=lT_0})}{\eta_l^3\, F(x_{t=0})}}\,\tau_0 \right\rceil. \quad (19)$$

Observe that when the learning rate becomes smaller, the communication period τ_l increases. This result matches the intuition that a small learning rate reduces the discrepancy between the local models, and hence is more tolerant of large communication periods.

Equation (19) states that the communication period should be proportional to (η_0/η_l)^{3/2}. However, in practice it is common to decay the learning rate by a factor of 10 after some given number of epochs. Such a dramatic change of learning rate may push the communication period to an unreasonably large value. In the experiments, we observe that when applying (19), the communication period can increase to τ = 1000, which causes the training loss to diverge.

To avoid this issue, we propose the adaptive strategy given by (20) below, which can also be justified by theoretical analysis. Suppose that in the l-th time interval, the objective function has a local Lipschitz smoothness L_l. Then, using the approximation η_l L_l ≈ 1, which is common in the SGD literature (Balles et al., 2016), we derive the following adaptive strategy:

$$\tau_l = \left\lceil \sqrt{\frac{\eta_0^3 L_0^2\, F(x_{t=lT_0})}{\eta_l^3 L_l^2\, F(x_{t=0})}}\,\tau_0 \right\rceil \approx \left\lceil \sqrt{\frac{\eta_0\, F(x_{t=lT_0})}{\eta_l\, F(x_{t=0})}}\,\tau_0 \right\rceil. \quad (20)$$

Apart from coupling the communication period with the learning rate, when to decay the learning rate is another key design factor. In order to eliminate the noise introduced by local updates, we choose to first gradually decay the communication period to 1 and then decay the learning rate as usual. For example, if the learning rate is scheduled to be decayed at the 80th epoch but at that time the communication period τ is still larger than 1, then we continue using the current learning rate until τ = 1.

4.4 Theoretical Guarantees for the Convergence of ADACOMM

In this subsection, we provide a convergence guarantee for the proposed adaptive communication scheme by extending the error analysis for PASGD. Without loss of generality, we analyze an arbitrary communication period sequence {τ_0, ..., τ_R}, where R represents the total number of communication rounds.² It will be shown that a decreasing sequence of τ is beneficial to the error-convergence rate.

Theorem 3 (Convergence of adaptive communication scheme). For PASGD with adaptive communication period and adaptive learning rate, suppose the learning rate remains the same in each local update period. If the following conditions are satisfied as R → ∞,

$$\sum_{r=0}^{R} \eta_r \tau_r \to \infty, \qquad \sum_{r=0}^{R} \eta_r^2 \tau_r < \infty, \qquad \sum_{r=0}^{R} \eta_r^3 \tau_r^2 < \infty, \quad (21)$$

then the averaged model x̄ is guaranteed to converge to a stationary point:

$$\mathbb{E}\left[\frac{\sum_{r=0}^{R-1} \eta_r \sum_{k=1}^{\tau_r} \|\nabla F(\bar{x}_{s_r+k})\|^2}{\sum_{r=0}^{R-1} \eta_r \tau_r}\right] \to 0, \quad (22)$$

where $s_r = \sum_{j=0}^{r-1} \tau_j$.

The proof details and a non-asymptotic result (similar to Theorem 1 but with variable τ) are provided in the Appendix. In order to understand the meaning of condition (21), let us first consider the case when τ_0 = ··· = τ_R is a constant. In this case, the convergence condition is identical to that of mini-batch SGD (Bottou et al., 2018):

$$\sum_{r=0}^{R} \eta_r \to \infty, \qquad \sum_{r=0}^{R} \eta_r^2 < \infty. \quad (23)$$

As long as the communication period sequence is bounded, it is trivial to adapt the learning rate scheme of mini-batch SGD (23) to satisfy (21). In particular, when the communication period sequence is decreasing, the last two conditions in (21) become easier to satisfy and put fewer constraints on the learning rate sequence.

²Note that in the error analysis, the subscripts of the communication period and learning rate represent the index of local update periods rather than the index of the T_0-length wall-clock time intervals considered in Sections 4.1-4.3.

[Figure 8 (bar chart): computation versus communication time per 100 iterations; plot data not recoverable.]
Figure 8. Wall-clock time to finish 100 iterations in a cluster with 4 worker nodes. To achieve the same level of communication/computation ratio, VGG-16 requires a larger communication period than ResNet-50.

5 EXPERIMENTAL RESULTS

5.1 Experimental Setting

Platform. The proposed adaptive communication scheme was implemented in PyTorch (Paszke et al., 2017) with Mpi4Py (Dalcín et al., 2005). All experiments were conducted on a local cluster with 4 worker nodes, each of which has an NVIDIA TitanX GPU and a 16-core Intel Xeon CPU. Worker nodes are connected via a 40 Gbps (5000 MB/s) Ethernet interface.
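The update rule (18), with the learning-rate coupling of (20), can be sketched in a few lines of Python. This is a minimal illustrative sketch, not the authors' implementation; the function name and argument layout are our own, the loss values are assumed to be scalars recorded at the start of each T_0-second interval, and rounding the γ-decayed period up to an integer is a practical choice we add here.

```python
import math

def update_comm_period(tau_prev, tau0, loss_now, loss0, eta_now, eta0, gamma=0.5):
    """One AdaComm communication-period update, following eqs. (18) and (20):
    tau_l = ceil(sqrt(eta0 * F(x_{lT0}) / (eta_l * F(x_0))) * tau0),
    with a multiplicative gamma backoff when the rule fails to strictly
    decrease tau (i.e., the training loss has plateaued)."""
    tau_new = math.ceil(math.sqrt((eta0 * loss_now) / (eta_now * loss0)) * tau0)
    if tau_new < tau_prev:
        return tau_new
    # loss has plateaued: multiply the previous period by gamma < 1 instead,
    # rounding up and never dropping below fully synchronous SGD (tau = 1)
    return max(1, math.ceil(gamma * tau_prev))
```

With γ = 1/2, the value reported to work well in the experiments, the rule halves τ whenever the loss stops decreasing fast enough, instead of letting τ saturate.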
Due to space limitations, additional results with 8 nodes are listed in Appendix A.

Dataset. We evaluate our method on image classification tasks on the CIFAR-10 and CIFAR-100 datasets (Krizhevsky, 2009), each of which consists of 50,000 training images and 10,000 validation images, in 10 and 100 classes respectively. Each worker machine is assigned a partition of the dataset, which is randomly shuffled after every epoch.

Model. We train the deep neural networks VGG-16 (Simonyan & Zisserman, 2014) and ResNet-50 (He et al., 2016) from scratch.³ These two neural networks have different architectures and parameter sizes, resulting in different behavior under periodic averaging. As shown in Figure 8, for VGG-16 the communication time is about 4 times higher than the computation time. Thus, compared to ResNet-50, it requires a larger τ in order to reduce the runtime per iteration and achieve fast convergence. A similarly high communication/computation ratio is common in the literature; see (Lin et al., 2018; Harlap et al., 2018).

Hyperparameter Choice. The mini-batch size on each worker is 128; therefore, the total mini-batch size per iteration is 512. The initial learning rates for VGG-16 and ResNet-50 are 0.2 and 0.4 respectively. The weight decay for both networks is 0.0005. In the variable learning rate setting, we decay the learning rate by 10 after the 80th/120th/160th/200th epochs. We set the time interval length T_0 as 60 seconds (about 10 epochs for the initial communication period).

Metrics. We compare the performance of the proposed adaptive communication scheme against the following methods with a fixed communication period: (1) Baseline: fully synchronous SGD (τ = 1); (2) an extreme high-throughput case where τ = 100; (3) a manually tuned case where a moderate value of τ is selected after trial runs with different communication periods. Instead of training for a fixed number of epochs, we train all methods sufficiently long to converge, and compare the training loss and test accuracy, both of which are recorded after every 100 iterations.

³The implementations of VGG-16 and ResNet-50 follow this GitHub repository: https://github.com/meliketoy/wide-resnet.pytorch

5.2 Adaptive Communication in PASGD

We first validate the effectiveness of ADACOMM, which uses the communication period update rule (18) combined with (20), on original PASGD without momentum.

Figure 9 presents the results for VGG-16 for both fixed and variable learning rates. A large communication period τ initially results in a rapid drop in the error, but the error finally converges to a higher floor. By adapting τ, the proposed ADACOMM scheme strikes the best error-runtime trade-off in all settings. In Figure 9a, while fully synchronous SGD takes 33.5 minutes to reach 4 × 10⁻³ training loss, ADACOMM takes 15.5 minutes, achieving more than 2× speedup. Similarly, in Figure 9b, ADACOMM takes 11.5 minutes to reach 4.5 × 10⁻² training loss, achieving 3.3× speedup over fully synchronous SGD (38.0 minutes).

However, for ResNet-50, the communication overhead is no longer the bottleneck. For a fixed communication period, the negative effect of performing local updates becomes more obvious and cancels the benefit of low communication delay (see Figures 10b and 10c). It is not surprising that fully synchronous SGD is nearly the best among all fixed-τ methods in the error-runtime plot. Even in this extreme case, adaptive communication still achieves competitive performance. When combined with learning rate decay, the adaptive scheme is about 1.4 times faster than fully synchronous SGD (see Figure 10a, 15.0 versus 21.5 minutes to achieve 3 × 10⁻² training loss).

Table 1 lists the test accuracies in different settings; we report the best accuracy within a time budget for each setting. The results show that the adaptive communication method has better generalization than fully synchronous SGD. In the variable learning rate case, the adaptive method even gives better test accuracy than PASGD with the best fixed τ.

Table 1. Best test accuracies on CIFAR-10 in different settings (SGD without momentum).

MODEL       METHOD     FIXED LR   VARIABLE LR
VGG-16      τ = 1      90.5       92.75
            τ = 20     92.25      92.5
            τ = 100    92.0       92.4
            ADACOMM    91.1       92.85
ResNet-50   τ = 1      88.76      92.26
            τ = 5      90.42      92.26
            τ = 100    88.66      91.8
            ADACOMM    89.57      92.42

5.3 Adaptive Communication in Momentum SGD

The adaptive communication scheme is based on the joint error-runtime analysis for PASGD without momentum. However, it can also be extended to other SGD variants, and in this subsection we show that the proposed method works well for SGD with momentum.

5.3.1 Block Momentum in Periodic Averaging

Before presenting the empirical results, we describe how to introduce momentum in PASGD. A naive way is to apply momentum independently to each local model, where each worker maintains an independent momentum buffer holding the latest change in the parameter vector x. However, this does not account for the potentially dramatic change in x at each averaging step. When local models are synchronized, the local momentum buffer still contains the update steps from before averaging, resulting in a large momentum term in the first SGD step of each local update period. When τ is large, this large momentum term can side-track the SGD descent direction, resulting in slower convergence.

To address this issue, a block momentum scheme was proposed in (Chen & Huo, 2016) and applied to speech recognition tasks. The basic idea is to treat the local updates in each communication period as one big gradient step between two synchronized models, and to introduce a global momentum for this big accumulated step. The update rule can be written as follows in terms of the momentum u_j:

$$u_j = \beta_{\text{glob}}\, u_{j-1} + G_j, \quad (24)$$
$$x_{(j+1)\tau+1} = x_{j\tau+1} - \eta_j u_j, \quad (25)$$

where $G_j = \frac{1}{m}\sum_{i=1}^{m}\sum_{k=1}^{\tau} g(x^{(i)}_{j\tau+k})$ represents the accumulated gradients in the j-th local update period and β_glob denotes the global momentum factor. Moreover, workers can also run momentum SGD on their local models, but their local momentum buffers are cleared at the beginning of each local update period. That is, we restart momentum SGD on the local models after every averaging step. The same strategy was also suggested in Microsoft's CNTK framework (Seide & Agarwal, 2016). In our experiments, we set the global momentum factor as 0.3 and the local momentum factor as 0.9, following (Lin et al., 2018). In the fully synchronous case, there is no need to introduce the block momentum, and we simply follow common practice, setting the momentum factor as 0.9.

5.3.2 ADACOMM plus Block Momentum

In Figure 11, we apply our adaptive communication strategy to PASGD with block momentum and observe significant performance gains on CIFAR-10/100. In particular, the adaptive communication scheme has the fastest convergence rate with respect to wall-clock time over the whole training process. While fully synchronous SGD gets stuck on a plateau before the first learning rate decay, the training loss of the adaptive method continuously decreases until convergence. For VGG-16 in Figure 11b, ADACOMM is 3.5× faster (in terms of wall-clock time) than fully synchronous SGD in reaching a 3 × 10⁻³ training loss. For ResNet-50 in Figure 11a, ADACOMM takes 15.8 minutes to reach 2 × 10⁻² training loss, which is 2 times faster than fully synchronous SGD (32.6 minutes).

6 CONCLUDING REMARKS

The design of communication-efficient SGD algorithms that are robust to system variability is vital to scaling machine learning training to resource-limited computing nodes. This paper is one of the first to analyze the convergence of error with respect to wall-clock time instead of the number of iterations, by accounting for the effect of computation and communication delays on the runtime per iteration.
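The two-step block momentum update (24)-(25) can be sketched as follows. This is a minimal NumPy sketch under our own data layout (a list of per-worker lists of per-step gradient arrays); it is not the paper's implementation, and the function name is hypothetical.

```python
import numpy as np

def block_momentum_step(x_sync, u_prev, local_grads, eta, beta_glob=0.3):
    """One synchronization round of block momentum (Chen & Huo, 2016),
    following eqs. (24)-(25): the accumulated local gradients G_j act as one
    big step, smoothed by a global momentum buffer u_j.

    local_grads: list over the m workers, each a list over the tau local
    steps of gradient arrays g(x^{(i)}_{j*tau+k})."""
    m = len(local_grads)
    # G_j = (1/m) * sum_i sum_k g(x^{(i)}_{j*tau+k})
    G = sum(np.sum(np.stack(g_i), axis=0) for g_i in local_grads) / m
    u = beta_glob * u_prev + G       # eq. (24): global momentum buffer
    x_next = x_sync - eta * u        # eq. (25): update the synchronized model
    return x_next, u
```

In the fully synchronous case (τ = 1), G_j reduces to the usual averaged mini-batch gradient and the scheme degenerates to ordinary momentum SGD on the shared model.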
We present a theoretical analysis of the error-runtime trade-off for periodic-averaging SGD (PASGD), where each node performs local updates and the models are averaged after every τ iterations. Based on the joint error-runtime analysis, we design the first (to the best of our knowledge) adaptive communication strategy, called ADACOMM, for distributed deep learning. Experimental results using VGGNet and ResNet show that the proposed method can achieve up to a 3× improvement in runtime while reaching the same error floor as fully synchronous SGD.

Going beyond periodic-averaging SGD, our idea of adapting the frequency of averaging distributed SGD updates can be easily extended to other SGD frameworks, including elastic averaging (Zhang et al., 2015), decentralized SGD (e.g., adapting network sparsity) (Lian et al., 2017), and parameter server-based training (e.g., adapting asynchrony).

[Figure 9: training loss and communication period versus wall-clock time for VGG-16; plot data not recoverable.]
Figure 9. ADACOMM on VGG-16: achieves 3.3× speedup over fully synchronous SGD (in (b), 11.5 versus 38.0 minutes to achieve 4.5 × 10⁻² training loss). (a) Variable learning rate on CIFAR-10. (b) Fixed learning rate on CIFAR-10. (c) Fixed learning rate on CIFAR-100.

[Figure 10: training loss and communication period versus wall-clock time for ResNet-50; plot data not recoverable.]
Figure 10. ADACOMM on ResNet-50: achieves 1.4× speedup over Sync SGD (in (a), 15.0 versus 21.5 minutes to achieve 3 × 10⁻² training loss). (a) Variable learning rate on CIFAR-10. (b) Fixed learning rate on CIFAR-10. (c) Fixed learning rate on CIFAR-100.

[Figure 11: training loss and communication period versus wall-clock time with block momentum; plot data not recoverable.]
Figure 11. ADACOMM with block momentum achieves 3.5× speedup over Sync SGD (in (b), 19.0 versus 66.7 minutes to achieve 3 × 10⁻³ training loss). (a) ResNet-50 on CIFAR-10. (b) VGG-16 on CIFAR-10. (c) ResNet-50 on CIFAR-100.

ACKNOWLEDGMENTS

The authors thank Prof. Greg Ganger for helpful discussions. This work was partially supported by NSF CCF-1850029 and an IBM Faculty Award. Experiments were conducted on clusters provided by the Parallel Data Lab at CMU.

REFERENCES

Balles, L., Romero, J., and Hennig, P. Coupling adaptive batch sizes with learning rates. arXiv preprint arXiv:1612.05086, 2016.

Bottou, L., Curtis, F. E., and Nocedal, J. Optimization methods for large-scale machine learning. SIAM Review, 60(2):223-311, 2018.

Chaudhari, P., Baldassi, C., Zecchina, R., Soatto, S., Talwalkar, A., and Oberman, A. Parle: parallelizing stochastic gradient descent. arXiv preprint arXiv:1707.00424, 2017.

Chen, K. and Huo, Q. Scalable training of deep learning machines by incremental block training with intra-block parallel optimization and blockwise model-update filtering. In 2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 5880-5884. IEEE, 2016.

Cui, H., Cipar, J., Ho, Q., Kim, J. K., Lee, S., Kumar, A., Wei, J., Dai, W., Ganger, G. R., Gibbons, P. B., et al. Exploiting bounded staleness to speed up big data analytics. In 2014 USENIX Annual Technical Conference (USENIX ATC 14), pp. 37-48, 2014.

Dalcín, L., Paz, R., and Storti, M. MPI for Python. Journal of Parallel and Distributed Computing, 65(9):1108-1115, 2005.

David, H. A. and Nagaraja, H. N. Order Statistics. John Wiley, Hoboken, N.J., 2003.

Dean, J., Corrado, G., Monga, R., Chen, K., Devin, M., Mao, M., Senior, A., Tucker, P., Yang, K., Le, Q. V., et al. Large scale distributed deep networks. In Advances in Neural Information Processing Systems, pp. 1223-1231, 2012.

Dekel, O., Gilad-Bachrach, R., Shamir, O., and Xiao, L. Optimal distributed online prediction using mini-batches. Journal of Machine Learning Research, 13(Jan):165-202, 2012.

Dutta, S., Joshi, G., Ghosh, S., Dube, P., and Nagpurkar, P. Slow and stale gradients can win the race: Error-runtime trade-offs in distributed SGD. arXiv preprint arXiv:1803.01113, 2018.

Ghadimi, S. and Lan, G. Stochastic first- and zeroth-order methods for nonconvex stochastic programming. SIAM Journal on Optimization, 23(4):2341-2368, 2013.

Goyal, P., Dollár, P., Girshick, R., Noordhuis, P., Wesolowski, L., Kyrola, A., Tulloch, A., Jia, Y., and He, K. Accurate, large minibatch SGD: training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.

Gupta, S., Zhang, W., and Wang, F. Model accuracy and runtime tradeoff in distributed deep learning: A systematic study. In IEEE 16th International Conference on Data Mining (ICDM), pp. 171-180. IEEE, 2016.

Harlap, A., Narayanan, D., Phanishayee, A., Seshadri, V., Devanur, N., Ganger, G., and Gibbons, P. PipeDream: Fast and efficient pipeline parallel DNN training. arXiv preprint arXiv:1806.03377, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770-778, 2016.

Iandola, F. N., Moskewicz, M. W., Ashraf, K., and Keutzer, K. FireCaffe: near-linear acceleration of deep neural network training on compute clusters. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2592-2600, 2016.

Jiang, Z., Balu, A., Hegde, C., and Sarkar, S. Collaborative deep learning in fixed topology networks. In Advances in Neural Information Processing Systems, pp. 5906-5916, 2017.

Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.

Li, M., Andersen, D. G., Park, J. W., Smola, A. J., Ahmed, A., Josifovski, V., Long, J., Shekita, E. J., and Su, B.-Y. Scaling distributed machine learning with the parameter server. In OSDI, volume 14, pp. 583-598, 2014.

Lian, X., Zhang, C., Zhang, H., Hsieh, C.-J., Zhang, W., and Liu, J. Can decentralized algorithms outperform centralized algorithms? A case study for decentralized parallel stochastic gradient descent. In Advances in Neural Information Processing Systems, pp. 5336-5346, 2017.

Lin, T., Stich, S. U., and Jaggi, M. Don't use large mini-batches, use local SGD. arXiv preprint arXiv:1808.07217, 2018.

McMahan, H. B., Moore, E., Ramage, D., Hampson, S., et al. Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629, 2016.

Mitliagkas, I., Zhang, C., Hadjis, S., and Ré, C. Asynchrony begets momentum, with an application to deep learning. In 54th Annual Allerton Conference on Communication, Control, and Computing (Allerton), pp. 997-1004. IEEE, 2016.

Moritz, P., Nishihara, R., Stoica, I., and Jordan, M. I. SparkNet: Training deep networks in Spark. arXiv preprint arXiv:1511.06051, 2015.

Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in PyTorch. In NIPS-W, 2017.

Seide, F. and Agarwal, A. CNTK: Microsoft's open-source deep-learning toolkit. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 2135-2135. ACM, 2016.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L., Kindermans, P.-J., and Le, Q. V. Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489, 2017a.

Smith, V., Chiang, C.-K., Sanjabi, M., and Talwalkar, A. S. Federated multi-task learning. In Advances in Neural Information Processing Systems, pp. 4424-4434, 2017b.

Stich, S. U. Local SGD converges fast and communicates little. arXiv preprint arXiv:1805.09767, 2018.

Su, H. and Chen, H. Experiments on parallel training of deep neural network using model averaging. arXiv preprint arXiv:1507.01239, 2015.

Wang, J. and Joshi, G. Cooperative SGD: A unified framework for the design and analysis of communication-efficient SGD algorithms. arXiv preprint arXiv:1808.07576, 2018.

Yu, H., Yang, S., and Zhu, S. Parallel restarted SGD for non-convex optimization with faster convergence and less communication. arXiv preprint arXiv:1807.06629, 2018.

Zhang, J., De Sa, C., Mitliagkas, I., and Ré, C. Parallel SGD: When does averaging help? arXiv preprint arXiv:1606.07365, 2016.

Zhang, S., Choromanska, A. E., and LeCun, Y. Deep learning with elastic averaging SGD. In Advances in Neural Information Processing Systems, pp. 685-693, 2015.

Zhou, F. and Cong, G. On the convergence properties of a k-step averaging stochastic gradient descent algorithm for nonconvex optimization. arXiv preprint arXiv:1708.01012, 2017.

A ADDITIONAL EXPERIMENTAL RESULTS

In the 8-worker case, the communication among nodes is accomplished via the NVIDIA Collective Communications Library (NCCL). The mini-batch size on each node is 64. The initial learning rate is set as 0.2 for both VGG-16 and ResNet-50. In Figure 12a, while fully synchronous SGD takes 17.5 minutes to reach 10⁻² training loss, ADACOMM only takes 6.0 minutes, achieving about 2.9× speedup.

[Figure 12: training loss and communication period versus wall-clock time for VGG-16 with 8 workers; plot data not recoverable.]
Figure 12. ADACOMM on VGG-16 with 8 workers: achieves 2.9× speedup over Sync SGD (in (a), 6.0 versus 17.5 minutes to achieve 1 × 10⁻² training loss). (a) Variable learning rate on CIFAR-10. (b) Fixed learning rate on CIFAR-100. Test accuracies at convergence when using variable learning rate: 92.52% (τ = 1), 91.85% (τ = 20), 91.15% (τ = 100), and 92.72% (ADACOMM).

[Figure 13: training loss and communication period versus wall-clock time for ResNet-50 with 8 workers; plot data not recoverable.]
Figure 13. ADACOMM on ResNet-50 with 8 workers: achieves 1.6× speedup over Sync SGD (in (a), 11.15 versus 18.25 minutes to achieve 1 × 10⁻¹ training loss). (a) Variable learning rate on CIFAR-10. (b) Fixed learning rate on CIFAR-100.
Test accuracies at convergence when using variable learning rate: 91.93% (τ = 1), 91.51% (τ = 10), 90.46% (τ = 100), and 91.77% (ADACOMM).

B INEFFICIENT LOCAL UPDATES

It is worth noting an interesting phenomenon in the convergence of periodic-averaging SGD (PASGD). When the learning rate is fixed, PASGD with a fine-tuned communication period has better test accuracy than both fully synchronous SGD and the adaptive method, while its training loss remains higher than the latter two methods (see Figures 9 and 10). In particular, on the CIFAR-100 dataset, we observe about 5% improvement in test accuracy when τ = 5. To investigate this phenomenon, we evaluate the test accuracy for PASGD (τ = 15) at two frequencies: 1) every 135 iterations; 2) every 100 iterations. In the former case, the test accuracy is reported just after the averaging step. In the latter case, however, the test accuracy can come from either the synchronized/averaged model or a local model, since 100 is not divisible by 15.

From Figure 14, it is clear that the local models' accuracy is much lower than that of the synchronized model, even after the algorithm has converged. Thus, we conjecture that the improvement in test accuracy only occurs at the synchronized model. That is, after averaging, the test accuracy undergoes a rapid increase, but it decreases again in the following local steps due to noise in the stochastic gradients. Such behavior may depend on the geometric structure of the loss surface of specific neural networks. This observation also reveals that the local updates are inefficient, as they reduce the accuracy and make no progress. In this sense, it is necessary for PASGD to reduce the gradient variance by either decaying the learning rate or decaying the communication period.

[Figure 14: test accuracy versus epochs at the two evaluation frequencies; plot data not recoverable.]
Figure 14. PASGD (τ = 15) with ResNet-50 on CIFAR-10 (fixed learning rate, no momentum). There is about a 10% accuracy gap between the local models and the synchronized model.

C ASSUMPTIONS FOR CONVERGENCE ANALYSIS

The convergence analysis is conducted under the following assumptions, which are similar to the assumptions made in previous work on the analysis of PASGD (Zhou & Cong, 2017; Yu et al., 2018; Wang & Joshi, 2018; Stich, 2018). In particular, we make no assumptions on the convexity of the objective function. We also remove the uniform bound assumption on the norm of the stochastic gradients.

Assumption 1 (Lipschitz smoothness & lower bound on F). The objective function F(x) is differentiable and L-Lipschitz smooth, i.e., $\|\nabla F(x) - \nabla F(y)\| \le L\|x - y\|$. The function value is bounded below by a scalar $F_{\mathrm{inf}}$.

Assumption 2 (Unbiased estimation). The stochastic gradient evaluated on a mini-batch ξ is an unbiased estimator of the full-batch gradient: $\mathbb{E}_{\xi|x}[g(x)] = \nabla F(x)$.

Assumption 3 (Bounded variance). The variance of the stochastic gradient evaluated on a mini-batch ξ is bounded as

$$\mathbb{E}_{\xi|x}\|g(x) - \nabla F(x)\|^2 \le \beta\|\nabla F(x)\|^2 + \sigma^2,$$

where β and σ² are non-negative constants, inversely proportional to the mini-batch size.

D PROOF OF THEOREM 2: ERROR-RUNTIME CONVERGENCE OF PASGD

First, let us recall the error analysis of PASGD. We adapt the theorem from (Wang & Joshi, 2018).

Lemma 1 (Error convergence of PASGD (Wang & Joshi, 2018)).
For PASGD, under Assumptions 1 to 3, if the learning rate satisfies $\eta L + \eta^2 L^2 \tau(\tau-1) \le 1$ and all workers are initialized at the same point $\bar{x}_1$, then after K iterations, we have

$$\mathbb{E}\left[\min_{k\in[1,K]}\|\nabla F(\bar{x}_k)\|^2\right] \le \mathbb{E}\left[\frac{1}{K}\sum_{k=1}^{K}\|\nabla F(\bar{x}_k)\|^2\right] \le \frac{2[F(\bar{x}_1) - F_{\mathrm{inf}}]}{\eta K} + \frac{\eta L \sigma^2}{m} + \eta^2 L^2 \sigma^2 (\tau - 1), \quad (26)$$

where L is the Lipschitz constant of the objective function, σ² is the variance bound of the mini-batch stochastic gradients, and $\bar{x}_k$ denotes the averaged model at the k-th iteration.

From the runtime analysis in Section 2, we know that the expected runtime per iteration of PASGD is

$$\mathbb{E}[T_{\text{P-Avg}}] = Y + \frac{D}{\tau}. \quad (27)$$

Accordingly, the total wall-clock time of training for K iterations is

$$T = K\left(Y + \frac{D}{\tau}\right). \quad (28)$$

Then, directly substituting $K = T/\mathbb{E}[T_{\text{P-Avg}}]$ into (26), we complete the proof.

E PROOF OF THEOREM 3: THE BEST COMMUNICATION PERIOD

Taking the derivative of the upper bound (14) with respect to the communication period, we obtain

$$-\frac{2[F(\bar{x}_1) - F_{\mathrm{inf}}]\,\mathbb{E}[D]}{\eta T \tau^2} + \eta^2 L^2 \sigma^2. \quad (29)$$

Setting the derivative equal to zero, the communication period is

$$\tau^* = \sqrt{\frac{2(F(\bar{x}_1) - F_{\mathrm{inf}})\,\mathbb{E}[D]}{\eta^3 L^2 \sigma^2 T}}. \quad (30)$$

Since the second derivative of (14),

$$\frac{4[F(\bar{x}_1) - F_{\mathrm{inf}}]\,\mathbb{E}[D]}{\eta T \tau^3} > 0, \quad (31)$$

is positive, the value obtained in (30) must be a global minimum.

F PROOF OF THEOREM 4: ERROR-CONVERGENCE OF ADAPTIVE COMMUNICATION SCHEME

F.1 Notations

To facilitate the analysis, we first introduce some useful notation. Define matrices $X_k, G_k \in \mathbb{R}^{d\times m}$ that concatenate all local models and gradients:

$$X_k = [x_k^{(1)}, \ldots, x_k^{(m)}], \quad (32)$$
$$G_k = [g(x_k^{(1)}), \ldots, g(x_k^{(m)})]. \quad (33)$$

Besides, define the matrix $J = \mathbf{1}\mathbf{1}^\top/(\mathbf{1}^\top\mathbf{1})$, where $\mathbf{1}$ denotes the column vector $[1, 1, \ldots, 1]^\top$. Unless otherwise stated, $\mathbf{1}$ is a column vector of size m, and the matrix J and the identity matrix I are of size m × m, where m is the number of workers.

F.2 Proof

Let us first focus on the j-th local update period, where j ∈ {0, 1, ..., R}. Without loss of generality, suppose the local index of the j-th local update period starts from 1 and ends with τ_j. Then, for the k-th local step in this period, we have the following lemma.

Lemma 2 (Lemma 1 in (Wang & Joshi, 2018)). For PASGD, under Assumptions 1 to 3, at the k-th iteration we have the following bound on the objective value:

$$\mathbb{E}[F(\bar{x}_{k+1})] - F(\bar{x}_k) \le -\frac{\eta_j}{2}\|\nabla F(\bar{x}_k)\|^2 - \frac{\eta_j}{2}\left[1 - \eta_j L\left(\frac{\beta}{m} + 1\right)\right]\cdot\frac{\|\nabla F(X_k)\|_F^2}{m} + \frac{\eta_j^2 L \sigma^2}{2m} + \frac{\eta_j L^2}{2m}\|X_k(I - J)\|_F^2, \quad (34)$$

where $\bar{x}_k$ denotes the averaged model at the k-th iteration.

Taking the total expectation and summing over all iterates in the j-th local update period, we obtain

$$\mathbb{E}\left[F(\bar{x}_{\tau_j+1}) - F(\bar{x}_1)\right] \le -\frac{\eta_j}{2}\sum_{k=1}^{\tau_j}\mathbb{E}\|\nabla F(\bar{x}_k)\|^2 - \frac{\eta_j}{2}\left[1 - \eta_j L\left(\frac{\beta}{m} + 1\right)\right]\sum_{k=1}^{\tau_j}\frac{\mathbb{E}\|\nabla F(X_k)\|_F^2}{m} + \frac{\eta_j^2 L \sigma^2 \tau_j}{2m} + \frac{\eta_j L^2}{2m}\sum_{k=1}^{\tau_j}\mathbb{E}\|X_k(I - J)\|_F^2. \quad (35)$$

Next, we provide an upper bound for the last term in (35).
Note that\r\n X (I\u2212J)=X (I \u2212J)\u2212\u03b7 G (I \u2212J) (36)\r\n k k\u22121 j k\u22121\r\n =X (I\u2212J)\u2212\u03b7 G (I\u2212J)\u2212\u03b7 G (I\u2212J) (37)\r\n k\u22122 j k\u22122 j k\u22121\r\n k\u22121\r\n =X(I\u2212J)\u2212\u03b7 XG (I\u2212J) (38)\r\n 1 j r\r\n r=1\r\n k\u22121\r\n =\u2212\u03b7 XG(I\u2212J) (39)\r\n j r\r\n r=1\r\n where (39) follows the fact that all workers start from the same point at the beginning of each local update period, i.e.,\r\n X1(I\u2212J)=0.Accordingly,wehave\r\n \uf8ee \uf8f9\r\n \r \r\r\n k\u22121 2\r\n h i \rX \r\r\n 2 2 \r \r\r\n E kX (I\u2212J)k =\u03b7 E\uf8f0 G(I\u2212J) \uf8fb (40)\r\n k F j \r r \r\r\n \rr=1 \r\r\n \uf8ee \uf8f9 F \uf8ee \uf8f9\r\n \r \r \r \r\r\n k\u22121 2 m k\u22121 2\r\n \rX \r X \rX \r\r\n 2 \r \r 2 \r (i) \r\r\n \u2264\u03b7 E\uf8f0 G \uf8fb=\u03b7 E\uf8f0 g(x ) \uf8fb (41)\r\n j \r r\r j \r r \r\r\n \rr=1 \r i=1 \rr=1 \r\r\n F\r\n 2 2 2\r\n wheretheinequality(41)isduetotheoperatornormof(I\u2212J)islessthan1. Furthermore,usingthefact(a+b) \u2264 2a +2b ,\r\n one can get\r\n \uf8ee \uf8f9\r\n \r \r\r\n m k\u22121 k\u22121 2\r\n h i X \rX\u0010 \u0011 X \r\r\n 2 2 \r (i) (i) (i) \r\r\n E kX (I\u2212J)k \u2264\u03b7 E\uf8f0 g(x )\u2212\u2207F(x ) + \u2207F(x ) \uf8fb (42)\r\n k F j \r r r r \r\r\n i=1 \rr=1 r=1 \r\r\n \uf8ee\r \r \uf8f9 \uf8ee\r \r \uf8f9\r\n m k\u22121 2 m k\u22121 2\r\n X \rX\u0010 \u0011\r X \rX \r\r\n \u22642\u03b72 E\uf8f0\r g(x(i)) \u2212 \u2207F(x(i)) \r \uf8fb+2\u03b72 E\uf8f0\r \u2207F(x(i))\r \uf8fb. (43)\r\n j \r r r \r j \r r \r\r\n i=1 \rr=1 \r i=1 \rr=1 \r\r\n | {z } | {z }\r\n T1 T2\r\n Forthe\ufb01rsttermT ,sincethestochasticgradientsareunbiased,allcrosstermsarezero. 
Thus, combining with Assumption 3, we have
\begin{align}
T_1 &= 2\eta_j^2 \sum_{i=1}^{m} \sum_{r=1}^{k-1} \mathbb{E}\left[\big\|g(x_r^{(i)}) - \nabla F(x_r^{(i)})\big\|^2\right] \tag{44} \\
&\le 2\eta_j^2 \sum_{i=1}^{m} \sum_{r=1}^{k-1} \left[\beta\, \mathbb{E}\left[\big\|\nabla F(x_r^{(i)})\big\|^2\right] + \sigma^2\right] \tag{45} \\
&= 2\eta_j^2 \beta \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] + 2\eta_j^2 m (k-1) \sigma^2. \tag{46}
\end{align}
For the second term in (43), directly applying Jensen's inequality, we get
\begin{align}
T_2 &\le 2\eta_j^2 (k-1) \sum_{i=1}^{m} \sum_{r=1}^{k-1} \mathbb{E}\left[\big\|\nabla F(x_r^{(i)})\big\|^2\right] \tag{47} \\
&= 2\eta_j^2 (k-1) \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right]. \tag{48}
\end{align}
Substituting the bounds on T_1 and T_2 into (43),
\[
\mathbb{E}\left[\|X_k(I - J)\|_F^2\right] \le 2\eta_j^2 \left[\beta + (k-1)\right] \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] + 2\eta_j^2 m (k-1) \sigma^2. \tag{49}
\]
Recalling the upper bound (35), we further derive the following bound:
\begin{align}
\sum_{k=1}^{\tau_j} \mathbb{E}\left[\|X_k(I - J)\|_F^2\right] &\le 2\eta_j^2 \sum_{k=1}^{\tau_j} \left[\beta + (k-1)\right] \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] + 2\eta_j^2 m \sigma^2 \sum_{k=1}^{\tau_j} (k-1) \tag{50} \\
&= 2\eta_j^2 \sum_{k=1}^{\tau_j} \left[\beta + (k-1)\right] \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] + \eta_j^2 m \sigma^2 \tau_j (\tau_j - 1) \tag{51} \\
&= 2\eta_j^2 \sum_{k=2}^{\tau_j} \left[\beta + (k-1)\right] \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] + \eta_j^2 m \sigma^2 \tau_j (\tau_j - 1). \tag{52}
\end{align}
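Two elementary index sums drive the constants here: the step from (50) to (51) uses \sum_{k=1}^{\tau}(k-1) = \tau(\tau-1)/2, and the later step from (53) to (54) uses \sum_{k=2}^{\tau} 2[\beta + (k-1)] = 2\beta(\tau-1) + \tau(\tau-1). A quick exact check (illustrative, with small arbitrary ranges of \beta and \tau):

```python
# Verify the two closed-form index sums used in the derivation.
def check(beta, tau):
    s1 = sum(k - 1 for k in range(1, tau + 1)) == tau * (tau - 1) // 2
    s2 = (sum(2 * (beta + (k - 1)) for k in range(2, tau + 1))
          == 2 * beta * (tau - 1) + tau * (tau - 1))
    return s1 and s2

print(all(check(beta, tau) for beta in range(0, 6) for tau in range(1, 30)))  # True
```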
Then, since \sum_{r=1}^{k-1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] \le \sum_{r=1}^{\tau_j - 1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right], we have
\begin{align}
\sum_{k=1}^{\tau_j} \mathbb{E}\left[\|X_k(I - J)\|_F^2\right] &\le 2\eta_j^2 \sum_{r=1}^{\tau_j - 1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] \sum_{k=2}^{\tau_j} \left[\beta + (k-1)\right] + \eta_j^2 m \sigma^2 \tau_j (\tau_j - 1) \tag{53} \\
&= \eta_j^2 \sum_{r=1}^{\tau_j - 1} \mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right] \left[2\beta(\tau_j - 1) + \tau_j(\tau_j - 1)\right] + \eta_j^2 m \sigma^2 \tau_j (\tau_j - 1). \tag{54}
\end{align}
Plugging (54) into (35),
\begin{align}
\mathbb{E}\left[F(x_{\tau_j+1})\right] - F(x_1) &\le -\frac{\eta_j}{2} \sum_{k=1}^{\tau_j} \mathbb{E}\,\|\nabla F(x_k)\|^2 - \frac{\eta_j}{2}\left[1 - \eta_j L\left(\frac{\beta}{m} + 1\right)\right] \cdot \sum_{k=1}^{\tau_j} \frac{\mathbb{E}\left[\|\nabla F(X_k)\|_F^2\right]}{m} + \frac{\eta_j^2 L \sigma^2 \tau_j}{2m} \nonumber \\
&\quad + \frac{\eta_j^3 L^2}{2} \left[2\beta(\tau_j - 1) + \tau_j(\tau_j - 1)\right] \sum_{r=1}^{\tau_j - 1} \frac{\mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right]}{m} + \frac{\eta_j^3 L^2 \sigma^2 \tau_j (\tau_j - 1)}{2} \tag{55} \\
&\le -\frac{\eta_j}{2} \sum_{k=1}^{\tau_j} \mathbb{E}\,\|\nabla F(x_k)\|^2 - \frac{\eta_j}{2}\left[1 - \eta_j L\left(\frac{\beta}{m} + 1\right)\right] \cdot \sum_{k=1}^{\tau_j} \frac{\mathbb{E}\left[\|\nabla F(X_k)\|_F^2\right]}{m} + \frac{\eta_j^2 L \sigma^2 \tau_j}{2m} \nonumber \\
&\quad + \frac{\eta_j^3 L^2}{2} \left[2\beta(\tau_j - 1) + \tau_j(\tau_j - 1)\right] \sum_{r=1}^{\tau_j} \frac{\mathbb{E}\left[\|\nabla F(X_r)\|_F^2\right]}{m} + \frac{\eta_j^3 L^2 \sigma^2 \tau_j (\tau_j - 1)}{2}. \tag{56}
\end{align}
Note that when the learning rate satisfies
\[
\eta_j^2 L^2 (\tau_j - 1)(2\beta + \tau_j) + \eta_j L \left(\frac{\beta}{m} + 1\right) \le 1, \tag{57}
\]
the second and fourth terms in (56) combine into a non-positive quantity, and we have
\[
\mathbb{E}\left[F(x_{\tau_j+1}) - F(x_1)\right] \le -\frac{\eta_j}{2} \sum_{k=1}^{\tau_j} \mathbb{E}\,\|\nabla F(x_k)\|^2 + \frac{\eta_j^2 L \sigma^2 \tau_j}{2m} + \frac{\eta_j^3 L^2 \sigma^2 \tau_j (\tau_j - 1)}{2}. \tag{58}
\]
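The step from (56) to (58) can be spot-checked numerically (an illustration with arbitrary constants L, \beta, m, not values from the paper): whenever the learning-rate condition (57) holds, the combined coefficient multiplying the averaged gradient-norm sum in (56) is non-positive, so those terms can be dropped.

```python
# Illustrative check that condition (57) makes the combined coefficient of
# sum_k E||grad F(X_k)||_F^2 / m in (56) non-positive.
def cond_57(eta, L, beta, tau, m):
    return eta**2 * L**2 * (tau - 1) * (2 * beta + tau) + eta * L * (beta / m + 1) <= 1

def coeff(eta, L, beta, tau, m):
    # Note 2*beta*(tau-1) + tau*(tau-1) factors as (tau-1)*(2*beta + tau).
    return (-(eta / 2) * (1 - eta * L * (beta / m + 1))
            + (eta**3 * L**2 / 2) * (tau - 1) * (2 * beta + tau))

L_, beta, m = 1.0, 2.0, 8  # arbitrary example constants
ok = all(coeff(eta, L_, beta, tau, m) <= 0
         for tau in (1, 2, 5, 10, 50)
         for eta in (0.001, 0.005, 0.01, 0.05)
         if cond_57(eta, L_, beta, tau, m))
print(ok)  # True
```

Algebraically, coeff \le 0 is exactly equivalent to (57): multiply (57) through by \eta_j / 2 and rearrange.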
Suppose l_j = \sum_{r=0}^{j-1} \tau_r + 1 is the first index in the j-th local update period. Without loss of generality, we substitute the local index by the global index:
\[
\mathbb{E}\left[F(x_{l_{j+1}}) - F(x_{l_j})\right] \le -\frac{\eta_j}{2} \sum_{k=1}^{\tau_j} \mathbb{E}\left[\big\|\nabla F(x_{l_j + k - 1})\big\|^2\right] + \frac{\eta_j^2 L \sigma^2 \tau_j}{2m} + \frac{\eta_j^3 L^2 \sigma^2 \tau_j (\tau_j - 1)}{2}. \tag{59}
\]
Summing over all local periods from j = 0 to j = R, one obtains
\[
\mathbb{E}\left[F(x_{l_{R+1}}) - F(x_1)\right] \le -\frac{1}{2} \sum_{j=0}^{R} \eta_j \sum_{k=1}^{\tau_j} \mathbb{E}\left[\big\|\nabla F(x_{l_j + k - 1})\big\|^2\right] + \frac{L\sigma^2}{2m} \sum_{j=0}^{R} \eta_j^2 \tau_j + \frac{L^2 \sigma^2}{2} \sum_{j=0}^{R} \eta_j^3 \tau_j (\tau_j - 1). \tag{60}
\]
After minor rearranging, it is easy to see that
\[
\mathbb{E}\left[\sum_{j=0}^{R} \sum_{k=1}^{\tau_j} \eta_j \big\|\nabla F(x_{l_j + k - 1})\big\|^2\right] \le 2\left[F(x_1) - F^*\right] + \frac{L\sigma^2}{m} \sum_{j=0}^{R} \eta_j^2 \tau_j + L^2 \sigma^2 \sum_{j=0}^{R} \eta_j^3 \tau_j (\tau_j - 1). \tag{61}
\]
That is,
\[
\mathbb{E}\left[\frac{\sum_{j=0}^{R} \sum_{k=1}^{\tau_j} \eta_j \big\|\nabla F(x_{l_j + k - 1})\big\|^2}{\sum_{j=0}^{R} \eta_j \tau_j}\right] \le \frac{2\left[F(x_1) - F^*\right]}{\sum_{j=0}^{R} \eta_j \tau_j} + \frac{L\sigma^2}{m} \cdot \frac{\sum_{j=0}^{R} \eta_j^2 \tau_j}{\sum_{j=0}^{R} \eta_j \tau_j} + L^2 \sigma^2 \cdot \frac{\sum_{j=0}^{R} \eta_j^3 \tau_j (\tau_j - 1)}{\sum_{j=0}^{R} \eta_j \tau_j}. \tag{62}
\]

F.3 Asymptotic Result (Theorem 3)

In order for the upper bound (62) to converge to zero as R \to \infty, a sufficient condition is
\[
\lim_{R \to \infty} \sum_{j=0}^{R} \eta_j \tau_j = \infty, \qquad \lim_{R \to \infty} \sum_{j=0}^{R} \eta_j^2 \tau_j < \infty, \qquad \lim_{R \to \infty} \sum_{j=0}^{R} \eta_j^3 \tau_j^2 < \infty. \tag{63}
\]
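As a concrete illustration of condition (63) (partial sums only, not a proof; the schedule and constants are arbitrary choices, not from the paper): with a fixed communication period \tau_j = \tau and a decaying learning rate \eta_j = \eta_0/(j+1), the first series diverges like the harmonic series while the second and third are bounded by comparison with \sum 1/(j+1)^2 and \sum 1/(j+1)^3.

```python
# Partial sums for the schedule eta_j = eta0/(j+1), tau_j = tau.
tau, eta0 = 10, 1.0  # arbitrary example schedule
s1 = s2 = s3 = 0.0
for j in range(100000):
    eta = eta0 / (j + 1)
    s1 += eta * tau        # sum eta_j * tau_j     -> diverges (harmonic)
    s2 += eta**2 * tau     # sum eta_j^2 * tau_j   -> bounded by 10 * pi^2/6
    s3 += eta**3 * tau**2  # sum eta_j^3 * tau_j^2 -> bounded by 100 * zeta(3)
print(s1 > 100, s2 < 16.45, s3 < 120.3)
```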
This completes the proof of Theorem 3.

F.4 Simplified Result

We can obtain a simplified result when the learning rate is fixed to \eta. To be specific, we have
\begin{align}
\mathbb{E}\left[\frac{\sum_{j=0}^{R} \sum_{k=1}^{\tau_j} \big\|\nabla F(x_{l_j + k - 1})\big\|^2}{\sum_{j=0}^{R} \tau_j}\right] &\le \frac{2\left[F(x_1) - F^*\right]}{\eta \sum_{j=0}^{R} \tau_j} + \frac{\eta L \sigma^2}{m} \cdot \frac{\sum_{j=0}^{R} \tau_j}{\sum_{j=0}^{R} \tau_j} + \eta^2 L^2 \sigma^2 \cdot \frac{\sum_{j=0}^{R} \tau_j (\tau_j - 1)}{\sum_{j=0}^{R} \tau_j} \tag{64} \\
&\le \frac{2\left[F(x_1) - F^*\right]}{\eta \sum_{j=0}^{R} \tau_j} + \frac{\eta L \sigma^2}{m} + \eta^2 L^2 \sigma^2 \cdot \frac{\sum_{j=0}^{R} \tau_j (\tau_j - 1)}{\sum_{j=0}^{R} \tau_j}. \tag{65}
\end{align}
If we choose the total number of iterations K = \sum_{j=0}^{R} \tau_j, then
\[
\mathbb{E}\left[\frac{\sum_{k=1}^{K} \|\nabla F(x_k)\|^2}{K}\right] \le \frac{2\left(F(x_1) - F^*\right)}{\eta K} + \frac{\eta L \sigma^2}{m} + \eta^2 L^2 \sigma^2 \left(\frac{\sum_{j=0}^{R} \tau_j^2}{\sum_{j=0}^{R} \tau_j} - 1\right). \tag{66}
\]
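The simplification behind the last term of (66) is the identity \sum_j \tau_j(\tau_j - 1) / \sum_j \tau_j = (\sum_j \tau_j^2)/(\sum_j \tau_j) - 1, which can be checked exactly (an illustration; the period sequence below is an arbitrary example):

```python
# Exact rational-arithmetic check of the identity used to pass to (66).
from fractions import Fraction

taus = [1, 4, 2, 8, 5, 7]  # arbitrary sequence of communication periods tau_j
lhs = Fraction(sum(t * (t - 1) for t in taus), sum(taus))
rhs = Fraction(sum(t * t for t in taus), sum(taus)) - 1
print(lhs == rhs)  # True
```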