Paras Jain, Ajay Jain, Aniruddha Nrusimha, Amir Gholami, Pieter Abbeel, Joseph Gonzalez, Kurt Keutzer, Ion Stoica
Modern neural networks are increasingly bottlenecked by the limited capacity of on-device GPU memory. Prior work explores dropping activations as a strategy to scale to larger neural networks under a fixed memory budget. However, these heuristics assume a uniform cost per layer and consider only simple linear-chain architectures, which limits their usability. In this paper, we formalize the problem of trading off computation time against memory requirements in DNN training as the tensor rematerialization optimization problem. We develop a new system that solves this problem optimally in reasonable time (under an hour) using off-the-shelf MILP solvers. The resulting schedules subsequently accelerate millions of training iterations. Our optimization pass in TensorFlow 2.0 automatically yields real training speedups of up to 4.8x over prior work and enables up to a 5x increase in input size for real-world large networks.
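To make the time-memory trade-off concrete, the following is a minimal sketch on a toy linear chain with non-uniform per-layer costs. It is not the MILP formulation of this paper: the function name `best_checkpoints`, its parameters `cost`, `mem`, and `budget`, and the memory model are all illustrative assumptions. The sketch assumes that kept activations stay resident for the whole backward pass and that each dropped activation is recomputed exactly once, so peak memory is the kept memory plus the largest contiguous run of dropped activations. It then searches every subset by brute force, which is only feasible for very small chains.

```python
from itertools import combinations


def best_checkpoints(cost, mem, budget):
    """Brute-force the cheapest set of activations to keep on a linear chain.

    Toy model (an assumption, not the paper's MILP): kept activations stay
    resident throughout the backward pass, and every dropped activation is
    recomputed exactly once. Peak memory is therefore the kept memory plus
    the largest contiguous run of dropped activations.

    Returns (extra_compute, peak_memory, kept_indices), or None if no
    subset fits within the budget.
    """
    n = len(cost)
    best = None
    for k in range(n + 1):
        for kept in combinations(range(n), k):
            kept_set = set(kept)
            kept_mem = sum(mem[i] for i in kept)
            # Memory of the largest contiguous run of dropped activations.
            run = worst_run = 0
            for i in range(n):
                run = 0 if i in kept_set else run + mem[i]
                worst_run = max(worst_run, run)
            peak = kept_mem + worst_run
            if peak > budget:
                continue  # this schedule does not fit in memory
            # Each dropped activation costs one extra forward computation.
            extra = sum(cost[i] for i in range(n) if i not in kept_set)
            if best is None or extra < best[0]:
                best = (extra, peak, kept)
    return best


if __name__ == "__main__":
    cost = [1, 10, 1, 10, 1, 1]   # hypothetical per-layer compute costs
    mem = [4, 4, 4, 4, 4, 4]      # hypothetical per-layer activation sizes
    print(best_checkpoints(cost, mem, budget=16))
```

Under this toy model, with the costs above and a budget of 16, the search keeps both expensive layers and recomputes only 3 units of work, whereas a uniform heuristic that keeps every other activation (indices 0, 2, 4) reaches the same peak memory but recomputes 21 units. The gap comes from the heuristic ignoring per-layer cost.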