Daniel Kang, Deepti Raghavan, Peter Bailis, Matei Zaharia
Machine learning models are increasingly deployed in mission-critical settings such as vehicles, but these models can fail in complex ways. To prevent errors, ML engineering teams monitor and continuously improve these models. We propose a new abstraction, model assertions, that adapts the classical use of program assertions as a way to monitor and improve ML models. Model assertions are arbitrary functions over a model's inputs and outputs that indicate when errors may be occurring. For example, a developer may write an assertion that an object's class should stay the same across frames of video. Once written, these assertions can be used both for runtime monitoring and for improving a model at training time. In particular, we show that at runtime, model assertions can find high-confidence errors, where a model returns the wrong output with high confidence, which uncertainty-based monitoring techniques would not detect. We also propose two methods for using model assertions at training time. First, we propose a bandit-based active learning algorithm that samples from data flagged by assertions, and show that it can reduce labeling costs by up to 33% over traditional uncertainty-based methods. Second, we propose an API for generating "consistency assertions" (e.g., the class change example) and weak labels for inputs where the consistency assertions fail, and show that these weak labels can improve relative model quality by up to 46%. We evaluate both methods on four real-world tasks with video, LIDAR, and ECG data.
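To make the class-consistency example concrete, the following is a minimal sketch of what such an assertion might look like. It is not the API proposed in this work. The function name `class_consistency_assertion`, the per-frame output format of `(track_id, predicted_class)` pairs, and the use of a per-object majority vote are all assumptions made for illustration. The sketch also assumes that an upstream tracker has already assigned a stable identifier to each object across frames.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

# Hypothetical model output: for each video frame, a list of
# (track_id, predicted_class) pairs. The format is an assumption.
Frame = List[Tuple[int, str]]


def class_consistency_assertion(frames: List[Frame]) -> Dict[int, List[int]]:
    """Flag frames where a tracked object's class disagrees with its majority class.

    Returns a mapping from track_id to the frame indices that were flagged.
    Objects whose class never changes do not appear in the result.
    """
    # Gather every (frame_index, class) observation per tracked object.
    observations = defaultdict(list)
    for frame_index, frame in enumerate(frames):
        for track_id, predicted_class in frame:
            observations[track_id].append((frame_index, predicted_class))

    flagged: Dict[int, List[int]] = {}
    for track_id, obs in observations.items():
        # Majority class over the whole clip (illustrative choice, not the paper's rule).
        majority_class = Counter(cls for _, cls in obs).most_common(1)[0][0]
        inconsistent = [idx for idx, cls in obs if cls != majority_class]
        if inconsistent:
            flagged[track_id] = inconsistent
    return flagged


if __name__ == "__main__":
    clip = [
        [(1, "car")],
        [(1, "truck")],  # object 1 briefly changes class
        [(1, "car")],
        [(1, "car"), (2, "person")],
    ]
    print(class_consistency_assertion(clip))  # prints {1: [1]}
```

The flagged frames could serve either as runtime alerts or as candidates for labeling. Under the same assumptions, the majority class is one plausible source for a weak label on a flagged frame.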