Scaling Deep Learning Training with MPMD Pipeline Parallelism

Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover

Proceedings of Machine Learning and Systems 7 (MLSys 2025)

We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that lets users implement their own pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and infers the communication among them. We implement an MPMD runtime for the asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.16\times$ over the best-performing SPMD configuration.
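To make the target workload concrete, the sketch below uses plain JAX (not JaxPP's actual API, which is not shown here) to express a model split into two pipeline stages with gradients accumulated over microbatches. All names (`stage1`, `stage2`, `train_step`) are illustrative; the stages run sequentially on one device, whereas a system like JaxPP would distribute such stage tasks as asynchronous SPMD tasks across nodes according to a user-defined schedule.

```python
# Minimal sketch of gradient accumulation over a two-stage model in plain JAX.
# Assumption: function and parameter names are hypothetical, for illustration only.
import jax
import jax.numpy as jnp

def stage1(params, x):
    # First pipeline stage: a linear layer followed by a nonlinearity.
    return jnp.tanh(x @ params["w1"])

def stage2(params, h, y):
    # Second pipeline stage: output projection plus mean-squared-error loss.
    pred = h @ params["w2"]
    return jnp.mean((pred - y) ** 2)

def loss_fn(params, x, y):
    return stage2(params, stage1(params, x), y)

@jax.jit
def train_step(params, microbatches, lr=1e-2):
    # Gradient accumulation: sum gradients over microbatches before updating.
    # A pipeline schedule overlaps these per-microbatch computations across stages.
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    for x, y in microbatches:
        g = jax.grad(loss_fn)(params, x, y)
        grads = jax.tree_util.tree_map(jnp.add, grads, g)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w1": jax.random.normal(key1, (8, 16)),
          "w2": jax.random.normal(key2, (16, 1))}
microbatches = [(jnp.ones((4, 8)), jnp.zeros((4, 1))) for _ in range(4)]
params = train_step(params, microbatches)
```

In this sequential form the accumulation loop is trivially correct but leaves hardware idle; the point of a pipeline schedule is to interleave forward and backward passes of different microbatches across stages placed on different devices.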