Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference

Part of Proceedings of Machine Learning and Systems 6 (MLSys 2024) Conference


Authors

Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant Nair, Ilya Soloveychik, Purushotham Kamath

Abstract

Transformers have emerged as the standard architecture for Large Language Models (LLMs). In generative language models, the inference process involves two main phases: prompt processing and token generation. Token generation, which constitutes most of the computational load, primarily entails vector-matrix multiplications and interactions with the Key-Value ($\mathsf{KV}$) cache. This phase is memory bandwidth-bound due to the overhead of transferring weights and $\mathsf{KV}$ cache values from memory to the computing units, which involves relatively low compute intensity. This memory bottleneck becomes particularly prominent in applications that demand long-context and extensive text generation, both of which are increasingly crucial for LLMs. This paper introduces an innovative approach to mitigate the challenges associated with $\mathsf{KV}$ cache size and memory bandwidth utilization, termed "$\mathsf{Keyformer}$". $\mathsf{Keyformer}$ capitalizes on the observation that during generative inference, approximately 90% of the attention weight is concentrated on a select subset of tokens, which act as "key" tokens. $\mathsf{Keyformer}$’s key token identification takes into account the discarded tokens by utilizing a novel score function. By retaining only these "key" tokens in the $\mathsf{KV}$ cache, both the $\mathsf{KV}$ cache size and memory bandwidth usage are significantly reduced while maintaining the model’s accuracy. We evaluate $\mathsf{Keyformer}$’s effectiveness using three foundational models: GPT-J, Cerebras-GPT, and MPT, which employ various positional embedding algorithms. Our assessment covers a range of tasks, with a primary focus on summarization and conversation tasks that involve extended contexts. $\mathsf{Keyformer}$’s $\mathsf{KV}$ cache reduction improves inference latency by 2.1$\times$ and boosts token generation throughput by 2.4$\times$, all while preserving the model’s accuracy.
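The general idea of retaining only "key" tokens in the cache can be illustrated with a minimal sketch. It is not the Keyformer score function, which the abstract does not specify: the sketch ranks tokens by plain accumulated softmax attention, an assumed stand-in. The names `select_key_tokens`, `budget`, and `recent` are hypothetical, as is the choice to always keep a window of the most recent tokens.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def select_key_tokens(attn_logits, budget, recent):
    """Return the sorted indices of cached tokens to keep.

    attn_logits: list of rows, one per query step; each row holds one
                 attention logit per cached token (all rows the same
                 length, a simplification of real causal decoding).
    budget:      total number of tokens to keep in the cache.
    recent:      number of most recent tokens that are always kept
                 (an assumption of this sketch).
    """
    n = len(attn_logits[0])

    # Accumulate the attention weight each cached token received.
    # Assumed stand-in score, NOT the score function from the paper.
    acc = [0.0] * n
    for row in attn_logits:
        for i, w in enumerate(softmax(row)):
            acc[i] += w

    # Always keep the recent window, then fill the remaining budget
    # with the highest-scoring older tokens.
    recent_ids = set(range(max(0, n - recent), n))
    candidates = [i for i in range(n) if i not in recent_ids]
    candidates.sort(key=lambda i: acc[i], reverse=True)
    k = max(0, budget - len(recent_ids))
    return sorted(recent_ids | set(candidates[:k]))


# Two query steps over six cached tokens; tokens 0 and 3 draw most attention.
logits = [
    [5.0, 0.0, 0.0, 4.0, 0.0, 0.0],
    [4.0, 0.0, 0.0, 5.0, 0.0, 0.0],
]
kept = select_key_tokens(logits, budget=4, recent=2)
print(kept)
```

With a budget of four tokens and a recent window of two, the sketch keeps the two most recent tokens plus the two older tokens with the highest accumulated attention, so the cache shrinks from six entries to four.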