{"title": "Automatically batching control-intensive programs for modern accelerators", "book": "Proceedings of Machine Learning and Systems", "page_first": 390, "page_last": 399, "abstract": "We present a general approach to batching arbitrary computations for accelerators such as GPUs.  We show orders-of-magnitude speedups using our method on the No U-Turn Sampler (NUTS), a workhorse algorithm in Bayesian statistics.  The central challenge of batching NUTS and other Markov chain Monte Carlo algorithms is data-dependent control flow and recursion.  We overcome this by mechanically transforming a single-example implementation into a form that explicitly tracks the current program point for each batch member, and only steps forward those in the same place.  We present two different batching algorithms: a simpler, previously published one that inherits recursion from the host Python, and a more complex, novel one that implemenents recursion directly and can batch across it. We implement these batching methods as a general program transformation on Python source.  Both the batching system and the NUTS implementation presented here are available as part of the popular TensorFlow Probability software package.", "full_text": "                       AUTOMATICALLYBATCHINGCONTROL-INTENSIVE PROGRAMS FOR\r\n                                                          MODERNACCELERATORS\r\n                           Alexey Radul1 Brian Patton1 Dougal Maclaurin1 Matthew D. Hoffman2 Rif A. Saurous3\r\n                                                                        ABSTRACT\r\n                    Wepresent a general approach to batching arbitrary computations for accelerators such as GPUs. We show\r\n                    orders-of-magnitude speedups using our method on the No U-Turn Sampler (NUTS), a workhorse algorithm in\r\n                    Bayesian statistics. The central challenge of batching NUTS and other Markov chain Monte Carlo algorithms is\r\n                    data-dependent control \ufb02ow and recursion. We overcome this by mechanically transforming a single-example\r\n                    implementation into a form that explicitly tracks the current program point for each batch member, and only steps\r\n                    forward those in the same place. We present two different batching algorithms: a simpler, previously published\r\n                    one that inherits recursion from the host Python, and a more complex, novel one that implemenents recursion\r\n                    directly and can batch across it. We implement these batching methods as a general program transformation on\r\n                    Python source. Both the batching system and the NUTS implementation presented here are available as part of the\r\n                    popular TensorFlow Probability software package.\r\n               1    INTRODUCTION                                                  and communication overhead. Alternately, the same API\r\n               Modern machine learning accelerators such as GPUs are              functions can be used to construct an operation graph to be\r\n               oriented around Single Instruction Multiple Data (SIMD)            executed all at once. This is the so-called TensorFlow graph\r\n               parallelism\u2014doing the same thing to each item of a big             mode. The advantage is that graphs can be saved, loaded,\r\n               array of data at once. 
Machine learning programs optimized         and optimized before being run, and suffer less dispatch\r\n               for such accelerators generally consist of invoking kernels,       overhead. The disadvantage is that graph computations can-\r\n               where each kernel is a separately hand-tuned accelerator           not be interleaved with the host Python, and in particular\r\n               program for a speci\ufb01c function. Good utilization of the            graph mode cannot represent recursive computations. A\r\n               accelerator comes of making relatively few kernel calls,           third option is to further compile the graph with XLA (The\r\n               with each kernel processing a relatively large amount of           XLATeam,2017). XLA imposes even more restrictions,\r\n               data. In the case of a typical neural network workload, the        such as statically resolving the shapes of all intermediate\r\n               kernels would be \u201cmatrix multiplication\u201d or \u201cconvolution\u201d,         arrays, but offers the additional bene\ufb01t of fusing kernels\r\n               and the call sequence would encode the architecture of the         together, which reduces dispatch overhead even more.\r\n               neural network.                                                    Goodperformance in this programming style depends heav-\r\n               Let\u2019s brie\ufb02y look at the resulting programming model. This         ily on vectorization, both within the kernels and at the level\r\n               review is worded in the TensorFlow (Abadi et al., 2015)            of kernel inputs. One very common strategy for vectorizing\r\n               ecosystem, since that\u2019s the setting for our work, but other        machine learning programs is so-called batching: process-\r\n               machine learning frameworks are broadly similar. The top-          ing a batch of independent inputs in lock-step in order to\r\n               level program is generally written in Python, calling Ten-         get more play for vectorization. Batching can also reduce\r\n               sorFlow API functions that correspond to kernels such as           per-input memory pressure: in the case of a neural network\r\n               matrix multiplication. These functions can be executed im-         with N features, each input has size O(N), whereas the\r\n                                                                                  weight matrices can easily have size O(N2). Running mul-\r\n               mediately, in the so-called TensorFlow Eager mode. In this         tiple inputs through the layers of the network in lock-step\r\n               case they can be arbitrarily interleaved with the host Python,     can re-use each weight matrix for many examples before\r\n               including control \ufb02ow; but suffer corresponding dispatch           having to evict it from memory caches in order to load the\r\n                  1Google, Cambridge, Massachusetts, USA 2Google, New York,       next one.\r\n               NewYork,USA3Google,MountainView,California,USA.Cor-\r\n               respondence to: Alexey Radul <axch@google.com>.                    
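For concreteness, here is a minimal TensorFlow sketch (our illustration, not code from the paper) of this convention: the same kernel call processes a whole batch when the inputs simply carry a leading batch dimension.

    import tensorflow as tf

    w = tf.random.normal([100, 100])   # weights: size O(N^2), shared by the whole batch
    x = tf.random.normal([1, 100])     # a single example: size O(N)
    y = tf.matmul(x, w)                # one kernel dispatch, one example

    xs = tf.random.normal([512, 100])  # a batch of Z = 512 examples
    ys = tf.matmul(xs, w)              # still one kernel dispatch; the weight
                                       # matrix is re-used for all 512 examples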
It is standard practice in machine learning frameworks such as TensorFlow or PyTorch (Paszke et al., 2017) to code the kernels to accept extra input dimensions and operate elementwise across them. Consequently, coding a batched version of a straightline program is relatively straightforward, if somewhat tedious and error-prone. Simple neural networks being straightline, batch training is the norm. Obstacles arise, however, when one wishes to batch a program with control flow, such as conditionals or variable-length loops. Then it becomes necessary to keep track of which batch member takes which branch of each conditional, and to avoid or ignore computations on batch members at the wrong program point. The difficulty of doing this by hand impedes using sophisticated classical algorithms in machine learning. Despite the impedance, people have used tree searches (Silver et al., 2016), optimization routines (Amos & Kolter, 2017), and ordinary differential equation solvers (Chen et al., 2018) in machine learning work; what else could we accomplish if it were easier?

Additional obstacles arise when trying to run a recursive program on a modern machine learning framework, in batch or otherwise, because the dataflow graph representation cannot execute recursion natively. This is as true in XLA or TensorFlow graph mode as it is in other graph-oriented machine learning frameworks like Caffe (Jia et al., 2014). The user is therefore forced to fall back to eager-style execution, paying more communication overhead. If machine learning is to benefit fully from the last 60 years of computer algorithm development, we must be able to run recursive algorithms reasonably efficiently.

Our goal in this paper is to push the boundary of what classical algorithms can efficiently execute on accelerators, in the context of modern machine learning frameworks. In particular, we

• Introduce program-counter autobatching (Section 3), a global, static program transformation for batching programs with arbitrary control flow, and materializing recursion into an underlying dataflow system.

• Demonstrate that program-counter autobatching can successfully accelerate the No U-Turn Sampler, a classic algorithm from Bayesian statistics, by compiling its recursion into explicit stack management, and by statically constructing a schedule for running it on batches of inputs.

• Provide, using the same vocabulary, a formal description of local static autobatching (Section 2). This is a simpler and lower-overhead batching transformation with less batching power in the recursive case.

• Survey (Section 5) the local static autobatching systems (Agarwal, 2019; Bradbury & Fu, 2018; Bradbury et al., 2017-2019) that have been implemented for several machine learning frameworks.

• Directly compare these two autobatching strategies on a test problem from Bayesian statistics (Section 4).

Program-counter autobatching is available as a module in the popular TensorFlow Probability (The TFP Team, 2018-2019; Dillon et al., 2017) software package. That module also implements a local static autobatching variant for comparison.

[Figure 1 panels (grids omitted): per-batch-member Python-stack columns of values for the n variable, the left variable, and the program counter, for the batch of inputs 3, 7, 4, 5, shown alongside this Fibonacci listing:]

     1 def fibonacci(n):
     2     cond = n <= 1
     3     if cond:
     4         return 1
     5     else:
     6         n2 = n - 2
     7         left = fibonacci(n2)
     8         n1 = n - 1
     9         right = fibonacci(n1)
    10         return left + right

Figure 1. Runtime operation of a locally, statically auto-batched recursive Fibonacci program. This snapshot occurs on the batch of inputs 3, 7, 4, 5. The batching transformation adds storage for all the batch members and handles divergent control flow by masking. The recursion is handled in Python. In this example, there are two "active" logical threads about to execute lines 2-3 of the Fibonacci program, highlighted in red. There are also two logical threads suspended one Python stack frame earlier, waiting for the active threads to re-converge with them so they can all return from that frame. The runtime cannot batch together logical threads with different call stacks, because those stacks are embedded in the runtime's Python-level call stack. The left variable has no value in most of the shown stack frames because the program hasn't assigned it yet.

2 LOCAL STATIC AUTOBATCHING

The simplest batching strategy (whether automated or hand-coded) is to retain the graph of the computation as-is and just transform every operation into a batched equivalent. We call this local static autobatching. Intuitively, it's "local" because the pattern of composition of operations doesn't change, and every operation can be transformed on its own; and it's "static" because the batching schedule doesn't depend on the input data, and can thus be computed before starting execution.

When extending this idea to programs with control flow, it is necessary to at least introduce a mask of which batch members are "currently active". One then arranges to execute every control path that at least one batch member follows, and to avoid or ignore each computation for each batch member that did not take that path. If the program being batched is recursive, the recursion still has to be carried out by the control language, i.e., Python. The runtime operation thus looks like Figure 1.

Local static autobatching can be implemented in many styles. For the sake of clarity, we will describe it as a nonstandard interpretation of a simple control flow graph language, given in Figure 2. In addition to eliminating many incidental considerations, this presentation aligns with the presentation of program-counter autobatching in Section 3, which will be a (different) nonstandard interpretation of a very similar language. Going through this presentation first also allows us to compare to other local static autobatching systems more precisely, in Section 5.

The nonstandard interpretation itself is given in Algorithm 1. In addition to storage for all the batch member inputs, we maintain an active set (initially the whole batch) and a program counter (initially the start of the entry point). The active set is a mask: all inactive batch members are ignored and never modified until they become active. The program counter gives the program point (as a basic block index) each active batch member is waiting to execute. The execution model is simple: at each step, we select some basic block that has at least one active batch member and execute it in batch. We then update the data storage and program counters of just those locally active batch members. Repeat until all active batch members have exited the function, then return.

    Program P     ::= [F]
    Function F    ::= input x, body [B], output y
    Block B       ::= [op], t
    Operation op  ::= Primitive y = f(x) | Call y = F(x)
    Terminator t  ::= Jump i | Branch x i j | Return
    f             ::= sin | cos | ...

Figure 2. Syntax of locally batchable programs. We use [·] to denote ordered lists. The symbols x, y range over variable names, and i, j index blocks within the same function. We present a unary syntax for succinctness; the n-ary generalization is standard.

Algorithm 1 Local static autobatching
    Input: Function F with I basic blocks B_i, input variable x, and output variable y;
    Input: Batch size Z;
    Input: Data array T with leading dimension Z;
    Input: Active set A ⊆ {0, 1, ..., Z-1}.
    Initialize length-Z program counter pc = [0, 0, ..., 0]
    Initialize x = T
    while (for any b ∈ A, pc_b < I) do
        Set block index i = min over b ∈ A of pc_b
        Compute locally active set A' = {b ∈ A | pc_b = i}
        for op ∈ B_i do
            if op is (Primitive y = f(x)) then
                Compute outputs o = f(x)
                Set y_{A'} = o_{A'}
            else if op is (Call y = G(x)) then
                Recursively compute outputs: o = Local-static(G, Z, x, A')
                Set y_{A'} = o_{A'}
            end if
        end for
        if t_i is Jump j then
            Set pc_{A'} = j
        else if t_i is Branch x j k then
            for b ∈ A' do
                Set pc_b = j if x_b, otherwise pc_b = k
            end for
        else if t_i is Return then
            Set pc_{A'} = I
        end if
    end while
    return Current value of y
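As a concrete rendering of the Primitive case of Algorithm 1, the following minimal NumPy sketch (our illustration; the function and variable names are ours) shows masking style: run the kernel on every batch member, then commit results only for the locally active set.

    import numpy as np

    def run_primitive_masked(f, x, y, active):
        """Execute y = f(x) for just the batch members selected by `active`.

        x, y:   arrays with leading batch dimension Z
        active: boolean mask of shape [Z], the locally active set A'
        """
        o = f(x)                       # computed for ALL members, junk included
        return np.where(active, o, y)  # commit only the active members' results

    # Two of four batch members execute `sin` at this block; the other
    # two keep their old value of y untouched.
    x = np.array([0.0, 1.0, 2.0, 3.0])
    y = np.zeros(4)
    active = np.array([True, False, True, False])
    y = run_primitive_masked(np.sin, x, y, active)

The gather-scatter alternative discussed below would instead compute f only on x[active] and scatter the result back.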
If the block we are executing ends in a branch (i.e., the prelude of a source-language if statement), the locally active batch members may diverge, in that some may move to the true branch and some to the false. They will converge again when both of those branches complete, and we continue after the end of the if.

If the block we are executing contains a (potentially recursive) call to a function the user asked us to auto-batch, we appeal to the host language's function call facility. The only trick is to update the active set in the recursive autobatching invocation to include only the locally active batch members (i.e., those whose program counter was at that call).

Why does this work? Consider this runtime from the point of view of one batch member. It wants to execute some sequence of basic blocks, as given by the edits to its program counter. Every time the runtime runs one of those basic blocks, it updates the state of that batch member the same way it would if the batch had size 1. And every time the runtime runs some other block, it doesn't update the batch member at all. The only way this can fail is if some underlying batch operation in the platform doesn't treat batch members independently (e.g., if an error in one batch member causes an exception which aborts execution of all of them) or if some batch member doesn't terminate and starves the others.

There are two significant free choices in this runtime. The first is how to execute a primitive operation on some batch members but not others. Algorithm 1 is written in masking style: we run the primitive on all the batch members, and just ignore the results of the ones that were at different points in the program. This is simple and has very low in-system overhead, because masking is a cheap operation. The downside is that it wastes computation on the batch members that are going to be masked out, which can be significant if batch utilization is low. There is also the subtlety that this extra computation happens with junk data, which may trigger spurious failures in the underlying platform.

The other option for batching the primitive operations is to use the indices of the locally active batch members to gather the inputs into a smaller array, perform just the live computation, and then scatter the results back into the full runtime state. This avoids wasting computation and avoids computing on junk data, but gathering and scattering are more expensive than masking. Furthermore, the intermediate arrays will have statically indeterminate size, making the gather-scatter approach less effective on platforms like the XLA compiler for Google's TPUs (The XLA Team, 2017) that statically infer array shapes.

The second significant free choice in this runtime is the heuristic for selecting which basic block to run next. As long as we don't starve any blocks, any selection criterion will lead to a correct end result. Algorithm 1 encodes a surprisingly effective choice: always run the earliest available block in program order. This has the merit of being (relatively) predictable by the user; but more refined heuristics are definitely possible.

In our implementation, the frontend for this is a Python-embedded compiler. That is, it's a user-invoked AST transformation based on AutoGraph (Moldovan et al., 2018) that converts the user program into the form given in Figure 2. All the user's actual computations become Primitive operations, and the control and recursion constructs are encoded in a standard way in Jump, Branch, Call, and Return instructions.
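For intuition about the frontend's output, here is one plausible hand-written encoding (ours, not emitted by the actual compiler) of the Fibonacci program of Figure 1 into the language of Figure 2, with blocks as (operations, terminator) pairs:

    # Block 0: cond = n <= 1;  Branch cond 1 2
    # Block 1: the base case;  out = 1; Return
    # Block 2: the two recursive calls; out = left + right; Return
    fibonacci_cfg = {
        "input": "n",
        "output": "out",
        "blocks": [
            ([("Primitive", "cond", lambda n: n <= 1, ["n"])],
             ("Branch", "cond", 1, 2)),
            ([("Primitive", "out", lambda: 1, [])],
             ("Return",)),
            ([("Primitive", "n2", lambda n: n - 2, ["n"]),
              ("Call", "left", "fibonacci", ["n2"]),
              ("Primitive", "n1", lambda n: n - 1, ["n"]),
              ("Call", "right", "fibonacci", ["n1"]),
              ("Primitive", "out", lambda left, right: left + right,
               ["left", "right"])],
             ("Return",)),
        ],
    }

Algorithm 1 then interprets this structure in batch, dispatching on the operation tags.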
3 PROGRAM COUNTER AUTOBATCHING

The local static autobatching discussed in Section 2 has an interesting limitation. Because it relies on the Python stack to implement recursion, it cannot batch operations across different (recursive) calls to the same user function. So two batch members could be trying to execute the same code and not be batchable, because the system doesn't have a long-enough sight line to connect them. And, of course, relying on Python to manage the recursion imposes communication costs and limits the optimizations the underlying machine learning framework can do.

We can serve two purposes with one intervention by implementing the stack within the autobatching system. We choose to give each program variable its own stack (by extending the relevant array with another dimension), getting a runtime state that looks like Figure 3. The layout of the figure is intentionally the same as Figure 1 to emphasize that we are representing all the same stuff, just in a different way.

[Figure 3 panels (grids omitted): a stack-by-batch array of values for each of the n variable, the left variable, and the program counter, together with a length-Z array of stack pointers for each, shown alongside the same Fibonacci listing as Figure 1.]

Figure 3. Runtime operation of a program counter auto-batched recursive Fibonacci program. This snapshot occurs on the batch of inputs 6, 7, 8, 9. In addition to the batch dimension (across), the batching transformation also augments every non-temporary variable from the program with a stack dimension (down), and an array of stack pointers. Additionally, the runtime maintains a program counter variable that records which block each logical thread is waiting to execute. At each time step, the runtime selects a basic block to run (lines 2-3 in this example) and updates the state and program counter of the logical threads executing that block (highlighted in red). Because recursive state is captured explicitly in the arrays storing the data, the runtime doesn't need to itself rely on recursion in Python (the host language). This means both that it can be executed in TensorFlow's graph mode, and that it can let logical threads re-converge on function calls, even at different stack depths. Note that the stack for the n variable will only hold values in frames where the program counter hasn't moved past line 8, where n is last used. Conversely, left is only pushed in frames where the program counter is past line 7.

To implement this, we want a slightly different control flow graph language, shown in Figure 4. Since the runtime is now managing the stacks itself, we replace the Call instruction with explicit (per-variable) Push and Pop instructions, as well as PushJump for entering function bodies. The Push also computes the value to write to the top of the variable's stack. The language is otherwise the same as Figure 2, and indeed our implementation compiles to the latter first and then lowers from there to the former.

    Program P     ::= input x, code [B], output y
    Block B       ::= [op], t
    Operation op  ::= Push y = f(x) | Pop x
    Terminator t  ::= Jump i | Branch x i j | PushJump i j | Return
    f             ::= sin | cos | ...

Figure 4. Syntax of program counter batchable programs. We use [·] to denote ordered lists. The symbols x, y range over variable names, and i, j index blocks of the program. This syntax is also unary for succinctness. The difference from locally batchable programs (Figure 2) is that all control flow graphs are merged, and Call operations are replaced with explicit stack manipulation operations. Push and Pop save and load data; PushJump i j jumps to block i after setting up a return to block j; and Return returns by popping the program counter stack.

Algorithm 2 Program counter autobatching
    Input: Program P with I basic blocks B_i, input variable x, and output variable y;
    Input: Batch size Z; Stack depth limit D;
    Input: Data array T with leading dimension Z.
    Initialize D-by-Z program counter pc = [0, 0, ..., 0]
    Initialize length-Z stack indexes pc_stack = [0, 0, ..., 0]
    for each variable v do
        Initialize v to zeros with leading dimensions D, Z
        Initialize length-Z stack indexes v_stack = [0, 0, ..., 0]
    end for
    PUSH T onto x
    Initialize length-Z pc_top = pc[pc_stack]
    while any pc_top_b < I do
        Set next block index i = min over b of pc_top_b
        Compute locally active set A = {b | pc_top_b = i}
        for op ∈ B_i do
            if op is (Push y = f(x)) then
                Compute x_top = x[x_stack]
                Compute outputs o = f(x_top)
                PUSH o_A onto y_A
            else if op is (Pop x) then
                POP x_A
            end if
        end for
        if t_i is Jump j then
            Set pc_top_A = j
        else if t_i is Branch x j k then
            Compute x_top = x[x_stack]
            for b ∈ A do
                Set pc_top_b = j if x_top_b, otherwise pc_top_b = k
            end for
        else if t_i is PushJump j k then
            Set pc_top_A = k
            PUSH j onto pc_A
        else if t_i is Return then
            POP pc_A
        end if
    end while
    return Current value of y[y_stack]
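The per-variable stack representation of Figure 3 is easy to picture in array terms. The following minimal NumPy sketch (our illustration; the class and method names are ours) holds one program variable as a D-by-Z value array plus a length-Z vector of stack pointers, so that push, pop, and top touch only the selected batch members:

    import numpy as np

    D, Z = 16, 4  # stack depth limit, batch size

    class BatchedStackVariable:
        """values[d, b] is batch member b's value at stack depth d;
        ptr[b] is member b's current depth (-1 when the stack is empty)."""
        def __init__(self):
            self.values = np.zeros((D, Z))
            self.ptr = np.full(Z, -1, dtype=int)

        def push(self, new_tops, active):
            self.ptr[active] += 1
            self.values[self.ptr[active], np.where(active)[0]] = new_tops[active]

        def pop(self, active):
            self.ptr[active] -= 1

        def top(self):
            # A gather indexed by each member's own stack depth; entries
            # for empty stacks are garbage, as in the masking discussion.
            return self.values[self.ptr, np.arange(Z)]

The gathers in top() and the scatters in push() are exactly the costs that the optimizations of Section 3.2 work to avoid.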
3.1 Runtime

The runtime is spelled out in Algorithm 2. As compared with local static autobatching (Algorithm 1), the active set no longer persists across steps, while the explicit program counter takes on a more central role (hence the name). The program counter now has a stack dimension of its own. The locally active set can now include batch members at different stack depths, giving the runtime a chance to batch their computations together. Consequently, computations can converge by calling into the same subroutine from different code locations, and conversely diverge by returning thereto.

The major advantage of managing variable stacks and the program counter is that the runtime is no longer itself recursive, so it can be coded completely in a system like (graph-mode) TensorFlow or XLA that does not support recursion natively. In TensorFlow, the code corresponding to the main loop in Algorithm 2 looks like Figure 5. The loop body is a dispatch to the correct block based on the dynamic value of the next program counter. The main trick, common in partial evaluation (Jones et al., 1993), is that the block index pc on line 3 is static at TF graph build time. The block function (not shown) is hence just a direct transcription of the loop body in Algorithm 2. Picking out the pc-th program block and interpreting that block's operations will happen once, when constructing the TensorFlow graph, leaving just the Tensor operations needed to manipulate the autobatched variables. In effect, lines 3-4 are the final stage of our compiler, lowering the entire autobatched program from our representation of Figure 4 to raw TensorFlow operations.

    1 while next_pc < halt_pc:
    2     all_vars = tf.switch_case(next_pc,
    3         [block(pc, all_vars)
    4          for pc in valid_pcs])
    5     next_pc = tf.reduce_min(
    6         get_top(get_pc_var(all_vars)))

Figure 5. Pseudocode for implementing the main loop of Algorithm 2 in TensorFlow. The main idea is to dispatch at runtime on the dynamic next_pc (line 2), to the block compiled with the correct static pc (line 3). See text.
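Figure 5 elides some plumbing: tf.switch_case takes callables, and in graph mode the loop itself must be a tf.while_loop. One way to spell it out in real TensorFlow, as a sketch under our own naming conventions (blocks is a list of Python functions, one compiled transcription per basic block; all_vars is a structure of Tensors that caches the program-counter tops under the key "pc_top"; initial_pc is a scalar int32 Tensor):

    import tensorflow as tf

    def run_batched_program(blocks, halt_pc, initial_vars, initial_pc):
        """Drive the main loop of Algorithm 2 as a single tf.while_loop."""

        def cond(next_pc, all_vars):
            return next_pc < halt_pc

        def body(next_pc, all_vars):
            # Dynamic dispatch on next_pc to the block that was traced
            # with the corresponding *static* pc at graph-build time.
            all_vars = tf.switch_case(
                next_pc,
                [lambda pc=pc: blocks[pc](all_vars) for pc in range(halt_pc)])
            # The next block to run is the smallest program point any
            # batch member is waiting on.
            next_pc = tf.reduce_min(all_vars["pc_top"])
            return next_pc, all_vars

        return tf.while_loop(cond, body, (initial_pc, initial_vars))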
3.2 Optimizations

The price we pay for implementing our own stack is that reads from the stack structure must gather according to the stack depths (which may vary across batch members), and writes must correspondingly scatter. We implement five compiler optimizations to reduce this cost:

1. Each variable in the program being auto-batched gets its own stack, so we can optimize those stacks independently. In particular, we can arrange the stack operations in a per-variable caller-saves stack discipline to set up for later pop-push elimination.

2. The compiler statically detects whether each variable is live across an iteration of the runtime loop. Those that are not are temporary and don't need to be touched by the autobatching system at all: they exist only inside basic block executions.

3. The compiler also statically detects whether each variable is live across a recursive function call that might need to reuse it (at a different stack depth). Those that are not do not need a stack or a stack pointer, and autobatching only amounts to updating their top value with a mask to select only the active batch members.

4. For those variables that do require stacks, the runtime caches the top of each stack variable, so that repeated reads do not require large gather operations.

5. Finally, the compiler also statically cancels pairs of Pop followed by Push that have no intervening reads, and converts the latter into an in-place Update instruction (not shown) that only interacts with the cached top of each stack; a toy version of this rewrite appears below.

An important consequence of the above optimizations is that program counter autobatching will run a non-recursive program entirely without variable stacks (except for the program counter itself). It will thus replicate the performance of local static autobatching without needing a host-language stack for (non-recursive) function calls, while also being able to batch across them. Unlike a tracing-based system such as JAX (Bradbury et al., 2017-2019), this compiled approach also doesn't amount to inlining all function calls, so it can autobatch a program with significant subroutine reuse without combinatorial explosion in code (or traced graph) size.
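Optimization 5 is a classic peephole rewrite. As a toy sketch (our illustration, on an obvious list-of-tuples instruction encoding, handling only the adjacent case), it might read:

    def cancel_pop_push(ops):
        """Rewrite `Pop x` immediately followed by `Push x = ...` into an
        in-place `Update x = ...` that only touches the cached stack top."""
        out = []
        for op in ops:
            if op[0] == "Push" and out and out[-1] == ("Pop", op[1]):
                out.pop()                          # cancel the Pop...
                out.append(("Update",) + op[1:])   # ...and demote the Push
            else:
                out.append(op)
        return out

    # A frame is popped and immediately re-pushed for a recursive call:
    ops = [("Pop", "n"), ("Push", "n", "n - 1")]
    assert cancel_pop_push(ops) == [("Update", "n", "n - 1")]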
4 EXPERIMENTS

We evaluate autobatching on the No U-Turn Sampler (NUTS), a workhorse gradient-based Markov chain Monte Carlo method widely used in Bayesian statistics. We chose NUTS as a test target because (i) the standard presentation is a complex recursive function, prohibitively difficult to batch by hand, and (ii) batching across multiple independent chains can give substantial speedups.

Our results show that batching on GPUs enables NUTS to efficiently scale to thousands of parallel chains, and to get inferential throughput well beyond existing CPU-based systems such as Stan. By demonstrating performance gains from batching NUTS, we hope to contribute to a broader practice of running large numbers of independent Markov chains, for more precise convergence diagnostics and uncertainty estimates.

We test autobatched NUTS on two test problems:

• Exploring a 100-dimensional correlated Gaussian distribution.

• Inference in a Bayesian logistic regression problem with 10,000 synthetic data points and 100 regressors.

We test three forms of autobatching NUTS:

• Program counter autobatching, compiled entirely with XLA;

• Local static autobatching, executed entirely with TensorFlow Eager; and

• A hybrid: running the control operations of local static autobatching in TensorFlow Eager, but compiling the straight-line components (basic blocks) with XLA.

The purpose of the latter is to try to tease apart the benefits of compilation specifically for control flow versus compiling (and fusing together) straightline sequences of kernel invocations. It should be noted that identifying the basic blocks to compile separately is a nontrivial program transformation in its own right. It fits conveniently into our software framework, but represents considerable work to do by hand.
Figure 6. Performance of the auto-batched No U-Turn Sampler on the Bayesian logistic regression problem (100 latent dimensions, 10,000 data points). The batch size refers to the number of chains running in tandem. The reported gradients are the total across all chains, excluding waste due to synchronization. We compare the performance of program counter autobatching compiled with XLA to our local static autobatching executed in TensorFlow's Eager mode. We also include two baselines. One is the same program executed directly in Eager mode without autobatching (perforce running one batch member at a time). The other is the widely used and well-optimized Stan implementation of (a variant of) the same NUTS algorithm. Batching provides linear scaling on all tested platforms, until the underlying hardware saturates. See text for details of the experimental setup.

4.1 Runtime on logistic regression

In Figure 6, we measure the number of model gradient evaluations per second we can get out of batching NUTS in different ways. The main effect we see is GPU throughput scaling linearly with batch size. We also see the speedup from avoiding cycling to Python (on the host CPU!) to implement the recursion.

The behavior when running entirely on CPU is more nuanced. CPUs are superb at control flow and recursion as it is, so the main effect of batching seems to be to amortize away platform overhead, until we match the performance of the Stan system's long-optimized custom C++ at a batch size of a few hundred, or of just ten when compiling fully with XLA, whose per-leaf-kernel overhead is much smaller.

At very large batch sizes, however, the hybrid strategy of running local static autobatching in TensorFlow Eager but compiling the basic blocks with XLA outperforms all other NUTS implementations we tested. We are not quite sure why this happens, but we hypothesize that (1) it beats Stan because of better memory locality in batch evaluation of the target density; (2) it beats fully compiled autobatching by avoiding the overhead of gathering from and scattering to per-variable stack arrays; (3) it beats fully Eager local autobatching by avoiding per-leaf-kernel dispatch overheads; but (4) it's slower at low batch sizes because of per-fused-kernel invocation costs. We leave a complete investigation of this phenomenon to future work.

A few details of the experimental setup. These measurements are for the synthetic Bayesian logistic regression problem. The measured time counts only a warm run, excluding compilation, the one-time TensorFlow graph construction, etc. This allows measurements for small batch sizes not to be swamped by O(1) overhead. The CPU computations were run on a shared 88-core machine with 100 GiB of RAM allocated to the benchmark, in 32-bit floating-point precision. The GPU computations were run on a dedicated Tesla P100 GPU, also in 32-bit precision. In all cases, we slightly modified the published NUTS algorithm to take 4 steps of the leapfrog integrator at each leaf of the NUTS tree, to better amortize the control overhead. This has no effect on the soundness of the algorithm. The timings are the best of five independent runs. Due to technical limitations, the Stan baseline was run on different hardware. We scaled its throughput against a calibration run of program counter autobatching on the same machine and precision.

Figure 7. Utilization of batch gradient computation on the correlated Gaussian test problem. Utilization is less than 100% above 1 batch member because different batch members choose to use different numbers of gradients at each trajectory. We can see from the local-static line that on this problem, the longest trajectory that NUTS chooses at any iteration tends to be about four times longer than the average. Program counter autobatching recovers more utilization by batching gradients across 10 consecutive NUTS trajectories, instead of having to synchronize on trajectory boundaries.

4.2 Batch utilization on correlated Gaussian

We also measure the specific effect of batching across recursion depths. The NUTS algorithm dynamically chooses how many gradient steps to take in each trajectory. When running a multi-step Markov chain, one therefore has a choice of whether to synchronize on trajectory boundaries or on individual gradient steps. Local autobatching systems can only implement the former, because the control structure of the whole batch computation necessarily follows the control structure of the user program. Program counter autobatching, however, can synchronize on the gradients, for example evaluating the 5th gradient of the 3rd trajectory of one batch member in tandem with the 8th gradient of the 2nd trajectory of another. In the regime where the model gradients are expensive relative to the inter-trajectory bookkeeping, the latter should be preferable.

We find in Figure 7 that on a synthetic correlated Gaussian, synchronizing on trajectory boundaries leaves as much as a factor of 4 on the table, even at a batch size as low as 30. Program counter autobatching is able to recover a factor of about 2 in utilization in as few as 10 NUTS trajectories. We expect gradient utilization to approach 1 in the limit of long chains, as the per-chain distribution of total gradients taken approaches a Gaussian.

5 RELATED WORK

The machine learning community has seen several systems arise to address the difficulty of hand-batching by batching user programs automatically. The extant ones have used one of two major architectures. The first is the local static autobatching we described in Section 2. The second, called dynamic batching, is exemplified by Neubig et al. (2017) and Looks et al. (2017). In this architecture, the runtime performs batching dynamically, by running parallel evaluations of the user program against a scheduler that manages the execution and batches opportunistically. From the lens of control flow, the advantage of dynamic batching is its ability to recover more batching (including within a single execution, if there is no data dependence). On the other hand, both local and program counter autobatching have less runtime overhead, because all the decisions about batch scheduling are done statically (at batch-program extraction time). For the same reason, the latter two architectures are more amenable to running on top of an existing machine learning framework, whereas dynamic batching requires support from the underlying graph executor.

The presentation in Section 2 gives one implementation style for the local static autobatching transformation. Other systems achieve the same effect in different ways.

Matchbox (Bradbury & Fu, 2018) relies on a relatively lightweight syntax transformation to intercept if statements and while loops, and otherwise accomplishes batching by defining a "batched array" type that carries the mask. The batched array overloads all the methods of a standard array with appropriate additional masking. Recursion is left to the ambient Python stack. In our terms, the mask corresponds to the active set.

At if statements, Matchbox first executes the then arm (if any batch members need it) and then the else. The program counter of Algorithm 1 is thus encoded in the queue (also maintained on the Python stack) of mask-block pairs to be executed. The data structure is equivalent: one vector of indices encodes the same information as a list of pairs of index with exclusive mask of items having that index. Storing the queue of program resumption points on the Python stack makes it more difficult for Matchbox to use a different heuristic for the order in which to run blocks, but the extant behavior is equivalent to the "run the earliest block" heuristic presented in Section 2.

Jax (Bradbury et al., 2017-2019) relies on an explicit tracing pass to construct an internal representation, on which batching (invoked via jax.vmap) is an explicit static transformation (one of several Jax can perform). Control flow requires using the Jax-specific control operators: lax.cond instead of if and lax.while_loop instead of while. Recursion is not supported in Jax at all, because it confuses the tracer. There is therefore no stack. The program counter is encoded in a mask and an execution sequence the same way it is in Matchbox, with the same effects.
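For a flavor of that interface, here is a small example of batching a conditional with current Jax (our illustration; under vmap the conditional is batched by evaluating both arms and selecting):

    import jax
    import jax.numpy as jnp
    from jax import lax

    def step(n):
        # The conditional is written with Jax's control operator instead
        # of a Python `if`, so that the tracer can see it.
        return lax.cond(n <= 1,
                        lambda n: jnp.int32(1),   # base-case arm
                        lambda n: n - 1,          # recursive-step arm
                        n)

    # vmap batches the conditional; the full recursive fibonacci cannot
    # be expressed this way, since tracing does not support recursion.
    print(jax.vmap(step)(jnp.array([0, 5, 2, 7])))  # [1 4 1 6]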
Similarly, TensorFlow's pfor facility (Agarwal, 2019; Agarwal & Ganichev, 2019) operates on the TensorFlow graph, including its tf.cond and tf.while_loop control operators. The transformation in pfor is the same as Jax's vmap, up to implementation details. Recursion is similarly not supported, because the underlying TensorFlow graph doesn't support it.

This is the sense in which the transformation is "local": this autobatching style (at least with this basic block choice heuristic) preserves the nesting structure of the user's original control constructs. As such, it can be implemented by a local transformation that, e.g., turns a while loop into a similar loop with a transformed body.

Somewhat farther afield, GPU programming languages such as CUDA (Nickolls et al., 2008) are also automatically vectorized. The handling of control constructs in CUDA is identical with local static autobatching, but of course it only applies to kernels written therein. An interesting potential direction for application-level automatic batching could be to forward surface-language control constructs through a compiler targeting CUDA (such as the GPU backend of the XLA compiler) and rely on CUDA to batch them. The ISPC compiler (Pharr & Mark, 2012) performs the same automatic vectorization transform for vector units in CPUs.

Finally, the NUTS algorithm is central enough to have prompted two rewrites in non-recursive form (Phan & Pradhan, 2019; Lao & Dillon, 2019) for the express purpose of running it on accelerators more effectively. The non-recursive form is also amenable to batching, either by hand or using an established autobatching system (whether local or dynamic). One would expect such a manual effort to obtain better performance, but its labor-intensiveness necessarily limits its scope.

6 CONCLUSION

We presented program counter autobatching, a novel method for automatically vectorizing batch computations at the machine learning framework level. Program counter autobatching handles arbitrary control flow in the source program, including batching operations across recursion depth. We demonstrated the efficacy of the method by mechanically batching a (recursive) implementation of the No U-Turn Sampler, obtaining speedups varying (with batch size) up to three orders of magnitude. An implementation of program counter autobatching is available in the TensorFlow Probability package.

ACKNOWLEDGMENTS

The authors would like to thank Delesley Hutchins for invaluable early critique of the architecture of the compiler, and the anonymous reviewers for their feedback.

REFERENCES

Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I., Harp, A., Irving, G., Isard, M., Jia, Y., Jozefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P., Vanhoucke, V., Vasudevan, V., Viégas, F., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., and Zheng, X. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL https://www.tensorflow.org/. Software available from tensorflow.org.
Software\r\n               as CUDA(Nickolls et al., 2008) are also automatically vec-          available from tensor\ufb02ow.org.\r\n               torized. The handling of control constructs in CUDA is            Agarwal, A. Static automatic batching in TensorFlow. In\r\n               identical with local static autobatching, but of course only        Chaudhuri, K. and Salakhutdinov, R. (eds.), Proceedings\r\n               applies to kernels written therein. An interesting potential        of the 36th International Conference on Machine Learn-\r\n               direction for application-level automatic batching could be         ing, volume 97 of Proceedings of Machine Learning Re-\r\n               to forward surface language control constructs through a            search, pp. 92\u2013101, Long Beach, California, USA, 09\u2013\r\n               compiler targeting CUDA (such as the GPU backend of                 15 Jun 2019. PMLR. URL http://proceedings.\r\n               the XLAcompiler) and rely on CUDA to batch them. The                mlr.press/v97/agarwal19a.html.\r\n               ISPC compiler (Pharr & Mark, 2012) performs the same\r\n               automatic vectorization transform for vector units in CPUs.       Agarwal, A. and Ganichev, I. Auto-vectorizing TensorFlow\r\n               Finally, the NUTS algorithm is central enough to have               graphs: Jacobians, auto-batching and beyond, 2019. URL\r\n                                                                                   https://arxiv.org/abs/1903.04243.\r\n                                                                 Autobatching\r\n              Amos,B.andKolter, J. Z. Optnet: Differentiable optimiza-    Moldovan, D., Decker, J. M., Wang, F., Johnson, A. A.,\r\n                tion as a layer in neural networks. In Proceedings of       Lee, B. K., Nado, Z., Sculley, D., Rompf, T., and\r\n                the 34th International Conference on Machine Learning-      Wiltschko, A. B. Autograph: Imperative-style coding\r\n                Volume 70, pp. 136\u2013145. JMLR. org, 2017.                    with graph-based performance, 2018.     URL https:\r\n                                                                            //arxiv.org/abs/1810.08061.\r\n              Bradbury, J. and Fu, C.         Automatic batching as\r\n                a compiler pass in PyTorch.         In Workshop on        Neubig, G., Goldberg, Y., and Dyer, C. On-the-\ufb02y operation\r\n                Systems for ML, Dec 2018.            URL http://            batching in dynamic computation graphs. In Advances in\r\n                learningsys.org/nips18/assets/papers/                       Neural Information Processing Systems, pp. 3971\u20133981,\r\n                107CameraReadySubmissionMatchbox_                           2017.\r\n                _LearningSys_Abstract_(2).pdf.                            Nickolls, J., Buck, I., Garland, M., and Skadron, K. Scalable\r\n              Bradbury, J., Frostig, R., Hawkins, P., Johnson, M., Leary,   parallel programming with cuda. Queue, 6(2):40\u201353,\r\n                C., Maclaurin, D., and Wanderman-Milne, S. JAX, 2017\u2013       March 2008. ISSN 1542-7730. doi: 10.1145/1365490.\r\n                2019. URL https://github.com/google/jax.                    1365500. URLhttp://doi.acm.org/10.1145/\r\n                Speci\ufb01cally the vmap functionality.                         1365490.1365500.\r\n              Chen, T. 
Somewhat farther afield, GPU programming languages such as CUDA (Nickolls et al., 2008) are also automatically vectorized. The handling of control constructs in CUDA is identical to local static autobatching, but of course only applies to kernels written therein. An interesting potential direction for application-level automatic batching could be to forward surface-language control constructs through a compiler targeting CUDA (such as the GPU backend of the XLA compiler) and rely on CUDA to batch them. The ISPC compiler (Pharr & Mark, 2012) performs the same automatic vectorization transform for vector units in CPUs.

Finally, the NUTS algorithm is central enough to have attracted hand-written batched implementations of its own (Phan & Pradhan, 2019; Lao & Dillon, 2019), each of which manually rewrites the recursion as iteration.

We demonstrated the efficacy of the method by mechanically batching a (recursive) implementation of the No U-Turn Sampler, obtaining speedups varying (with batch size) up to three orders of magnitude. An implementation of program counter autobatching is available in the TensorFlow Probability package.
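For readers who want the end result without the transformation details, here is a minimal sketch of running many NUTS chains in batch through TensorFlow Probability's public sampler interface. This uses the stock tfp.mcmc.NoUTurnSampler API with a standard-normal target purely for illustration; it shows the batched calling convention, not the autobatching machinery itself:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# 128 independent scalar chains targeting a standard normal: the
# log-probability function maps a [128]-shaped state to [128] values.
target_log_prob_fn = lambda x: -0.5 * tf.square(x)

kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.5)

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=500,
    current_state=tf.zeros([128]),  # one scalar state per chain
    kernel=kernel,
    trace_fn=None)

print(samples.shape)  # (1000, 128)
```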
ACKNOWLEDGMENTS

The authors would like to thank Delesley Hutchins for invaluable early critique of the architecture of the compiler; and the anonymous reviewers for their feedback.

REFERENCES

Abadi, M., Agarwal, A., Barham, P., Brevdo, E., Chen, Z., Citro, C., Corrado, G. S., Davis, A., Dean, J., Devin, M., Ghemawat, S., Goodfellow, I., Harp, A., Irving, G., Isard, M., Jia, Y., Jozefowicz, R., Kaiser, L., Kudlur, M., Levenberg, J., Mané, D., Monga, R., Moore, S., Murray, D., Olah, C., Schuster, M., Shlens, J., Steiner, B., Sutskever, I., Talwar, K., Tucker, P., Vanhoucke, V., Vasudevan, V., Viégas, F., Vinyals, O., Warden, P., Wattenberg, M., Wicke, M., Yu, Y., and Zheng, X. TensorFlow: Large-scale machine learning on heterogeneous systems, 2015. URL https://www.tensorflow.org/. Software available from tensorflow.org.

Agarwal, A. Static automatic batching in TensorFlow. In Chaudhuri, K. and Salakhutdinov, R. (eds.), Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pp. 92–101, Long Beach, California, USA, 09–15 Jun 2019. PMLR. URL http://proceedings.mlr.press/v97/agarwal19a.html.

Agarwal, A. and Ganichev, I. Auto-vectorizing TensorFlow graphs: Jacobians, auto-batching and beyond, 2019. URL https://arxiv.org/abs/1903.04243.

Amos, B. and Kolter, J. Z. OptNet: Differentiable optimization as a layer in neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 136–145. JMLR.org, 2017.

Bradbury, J. and Fu, C. Automatic batching as a compiler pass in PyTorch. In Workshop on Systems for ML, Dec 2018. URL http://learningsys.org/nips18/assets/papers/107CameraReadySubmissionMatchbox__LearningSys_Abstract_(2).pdf.

Bradbury, J., Frostig, R., Hawkins, P., Johnson, M., Leary, C., Maclaurin, D., and Wanderman-Milne, S. JAX, 2017–2019. URL https://github.com/google/jax. Specifically the vmap functionality.

Chen, T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. K. Neural ordinary differential equations. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa-Bianchi, N., and Garnett, R. (eds.), Advances in Neural Information Processing Systems 31, pp. 6571–6583. Curran Associates, Inc., 2018. URL http://papers.nips.cc/paper/7892-neural-ordinary-differential-equations.pdf.

Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., Patton, B., Alemi, A., Hoffman, M., and Saurous, R. A. TensorFlow Distributions, 2017. URL https://arxiv.org/abs/1711.10604.

Jia, Y., Shelhamer, E., Donahue, J., Karayev, S., Long, J., Girshick, R., Guadarrama, S., and Darrell, T. Caffe: Convolutional architecture for fast feature embedding. arXiv preprint arXiv:1408.5093, 2014. URL https://arxiv.org/abs/1408.5093.

Jones, N. D., Gomard, C. K., and Sestoft, P. Partial Evaluation and Automatic Program Generation. Prentice Hall International, 1993.

Lao, J. and Dillon, J. V. Unrolled implementation of no-U-turn sampler, August 2019. URL https://github.com/tensorflow/probability/blob/master/discussion/technical_note_on_unrolled_nuts.md. Software contributed to TensorFlow Probability as https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/mcmc/nuts.py.

Looks, M., Herreshoff, M., Hutchins, D., and Norvig, P. Deep learning with dynamic computation graphs. arXiv preprint arXiv:1702.02181, 2017. URL https://arxiv.org/abs/1702.02181.

Moldovan, D., Decker, J. M., Wang, F., Johnson, A. A., Lee, B. K., Nado, Z., Sculley, D., Rompf, T., and Wiltschko, A. B. AutoGraph: Imperative-style coding with graph-based performance, 2018. URL https://arxiv.org/abs/1810.08061.

Neubig, G., Goldberg, Y., and Dyer, C. On-the-fly operation batching in dynamic computation graphs. In Advances in Neural Information Processing Systems, pp. 3971–3981, 2017.

Nickolls, J., Buck, I., Garland, M., and Skadron, K. Scalable parallel programming with CUDA. Queue, 6(2):40–53, March 2008. ISSN 1542-7730. doi: 10.1145/1365490.1365500. URL http://doi.acm.org/10.1145/1365490.1365500.

Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in PyTorch. In NIPS-W, 2017.

Phan, D. and Pradhan, N. Iterative NUTS, May 2019. URL https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS.

Pharr, M. and Mark, W. R. ispc: A SPMD compiler for high-performance CPU programming. In 2012 Innovative Parallel Computing (InPar), pp. 1–13. IEEE, 2012. URL http://doi.acm.org/10.1145/1133255.1133997.

Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., van den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., Dieleman, S., Grewe, D., Nham, J., Kalchbrenner, N., Sutskever, I., Lillicrap, T., Leach, M., Kavukcuoglu, K., Graepel, T., and Hassabis, D. Mastering the game of Go with deep neural networks and tree search. Nature, 529:484–503, 2016. URL http://www.nature.com/nature/journal/v529/n7587/full/nature16961.html.

The TFP Team. TensorFlow Probability, 2018–2019. URL https://github.com/tensorflow/probability.

The XLA Team. XLA—TensorFlow, compiled, 2017. URL https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html.