FlashDecoding++: Faster Large Language Model Inference with Asynchronization, Flat GEMM Optimization, and Heuristics

Part of Proceedings of Machine Learning and Systems 6 (MLSys 2024) Conference


Authors

Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Yuhan Dong, Yu Wang

Abstract

As Large Language Models (LLMs) become increasingly important in various domains, the performance of LLM inference is crucial to a massive number of LLM applications. However, the following challenges remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update among partial softmax results, leading to ∼20% overhead for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The matrices involved in GEMM during LLM inference are flat, leading to under-utilized computation and a 50% performance loss after padding zeros in previous designs (e.g., cuBLAS, CUTLASS, etc.). (3) Performance loss due to static dataflow. Kernel performance in LLM inference depends on varied input data features, hardware configurations, etc. A single, static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware backends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. Based on this, fine-grained pipelining is proposed, leading to 1.05× and 1.14× speedups for the prefill and decoding stages of LLM inference, respectively. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Techniques such as double buffering are then introduced, leading to up to 52% speedup for the flat GEMM operation. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes the dataflow using different hardware resources (e.g., Tensor Cores or CUDA cores) while considering input dynamics. The design leads to up to 29% speedup compared with a static dataflow. Due to the versatility of its optimizations, FlashDecoding++ achieves up to 4.86× and 2.18× speedup on NVIDIA and AMD GPUs, respectively, compared with Hugging Face implementations. FlashDecoding++ also achieves an average 1.37× speedup compared with a state-of-the-art LLM inference engine, FlashDecoding, on various LLMs (e.g., Llama2, ChatGLM2, etc.).
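The unified-max idea in point (1) can be illustrated with a minimal NumPy sketch (not the authors' CUDA kernel): instead of each chunk computing its own local max and forcing a synchronized rescaling when partial results are merged, every chunk subtracts the same precomputed bound `PHI` (an assumed, statistically chosen upper bound on the attention logits), so partial numerators and denominators can simply be accumulated. A production kernel would also fall back to recomputation when a logit exceeds the bound.

```python
import numpy as np

# Hypothetical upper bound on attention logits; in practice this would be
# chosen from observed logit statistics, with a recomputation fallback if exceeded.
PHI = 6.0

def partial_exp(logits_chunk, phi=PHI):
    """exp(x - phi) for one chunk; no chunk-local max, hence no cross-chunk sync."""
    return np.exp(logits_chunk - phi)

def softmax_unified_max(logits, num_chunks=4, phi=PHI):
    """Sketch of softmax with a unified max value: chunks are processed
    independently, then their partial results are summed directly."""
    chunks = np.array_split(logits, num_chunks)
    partials = [partial_exp(c, phi) for c in chunks]   # independent per-chunk work
    numer = np.concatenate(partials)
    denom = sum(p.sum() for p in partials)             # plain accumulation, no rescaling
    return numer / denom

# Sanity check against the standard (synchronized-max) softmax.
x = np.random.randn(128)
ref = np.exp(x - x.max())
ref /= ref.sum()
assert np.allclose(softmax_unified_max(x), ref)
```

Mathematically the result is identical to the standard softmax because the shared offset cancels in the ratio; the benefit is that no partial result needs to be rescaled after the fact, which is what enables the asynchronized, fine-grained pipelining described in the paper.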