-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Add validation for missing model weights during checkpoint loading #16235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add validation for missing model weights during checkpoint loading #16235
Conversation
Summary of ChangesHello @davegolland, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a crucial validation mechanism to SGLang's model loading process, preventing a critical silent bug where models could load with randomly initialized parameters if their corresponding weights were absent from the provided checkpoint. By ensuring that all expected model parameters are properly loaded, this change significantly enhances model reliability and integrity, particularly when dealing with diverse model architectures, evolving versions, or custom configurations. The new system provides clear warnings or errors, allowing users and developers to quickly identify and address potential inconsistencies between model architecture and checkpoint data, thereby averting subtle but impactful performance degradation or incorrect outputs. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This is an excellent pull request that addresses a critical silent bug in model weight loading. The addition of validation for missing weights is a crucial improvement for model correctness and stability. The implementation is well-structured, and the inclusion of comprehensive unit tests is commendable. I've identified a critical issue in the fused parameter mapping logic within loader.py concerning both correctness and performance, and a high-severity bug in weight_utils.py related to identifying parameters in pipeline parallel stages. My suggestions aim to fix these issues while improving efficiency. Overall, this is a very valuable contribution.
a6e5d11 to
fbc278e
Compare
## Problem When loading model checkpoints, if a parameter exists in the model architecture but is missing from the checkpoint file, it retains its random initialization values. This bug silently produces incorrect model outputs without any error or warning. This issue has existed since the initial codebase and affects all 130+ supported model architectures. ## Root Cause The `load_weights()` method in model classes iterates through weights from the checkpoint and loads them into the model. However, there's no validation to ensure ALL model parameters were loaded. Parameters missing from the checkpoint are never encountered in the loop, so they keep their random initial values. Example scenario: - Model instantiated with 25 layers (layers 0-24) - Checkpoint only has 24 layers (layers 0-23) - Layer 24 (the last layer) has random weights → garbage output! - No error or warning is raised ## Solution This PR adds centralized weight loading validation in the model loader: 1. **Track loaded parameters**: Wrap the weight iterator to record which parameters are loaded from the checkpoint 2. **Validate completeness**: After loading, check that all required model parameters were loaded 3. **Handle legitimate exceptions**: Exclude parameters that are expected to be missing: - Computed parameters (rotary_emb.inv_freq, cos_sin_cache) - Tied weights (lm_head.weight when tied to embed_tokens.weight) - Pipeline parallelism placeholders (PPMissingLayer) - Fused parameters (qkv_proj loaded from q/k/v_proj) 4. **Opt-in strict mode**: Controlled via environment variable `SGLANG_STRICT_WEIGHT_LOADING=1` to avoid breaking existing deployments ## Changes - `weight_utils.py`: New `validate_loaded_weights()` function - `loader.py`: Modified `load_weights_and_postprocess()` to track and validate - `test_missing_weights_validation.py`: Unit tests for validation logic ## Testing - Unit tests cover all edge cases (tied weights, computed params, etc.) - Validation runs on every model load when enabled - Default behavior unchanged (warn only) to maintain compatibility ## Historical Context This bug was implicitly exposed by commit a3339d8 (Feb 2025) which fixed tied weight loading errors - that fix was needed because the lack of validation meant silent failures went unnoticed. Git history shows no previous attempts at strict weight validation, likely due to the complexity of edge cases (fused weights, tied embeddings, pipeline parallelism).
Address PR review feedback to improve both performance and correctness: 1. Fused parameter mapping (loader.py): - Fix O(n*m) performance issue by pre-building mapping dictionary - Fix string matching bugs with robust dotted-path component matching - Extract FUSED_PARAMS_MAPPING as class constant to avoid duplication - Use replace(count=1) to prevent multiple replacement edge cases 2. Pipeline parallelism validation (weight_utils.py): - Fix prefix matching bug where "layers.1" incorrectly matched "layers.10" - Improve from O(modules * params) to O(params) complexity - Add "." suffix to module prefixes for exact matching 3. Add comprehensive test coverage: - test_pp_missing_layer_prefix_correctness: Validates exact prefix matching - test_loader_fused_params.py: Tests qkv_proj and gate_up_proj mapping - Ensures correctness fixes are validated and prevent regression
fbc278e to
c90fb97
Compare
Add validation for missing model weights during checkpoint loading
Overview
This PR fixes a critical silent bug where model parameters that exist in the model architecture but are missing from checkpoint files retain random initialization values, causing incorrect model outputs without any error or warning.
The Problem
What Happens
When
load_weights()iterates through checkpoint weights, it only processes weights that exist in the checkpoint. If a model parameter is missing from the checkpoint, the iteration never encounters it, so it keeps its random initialization value.Real-World Scenarios
This bug can occur when:
config.jsonsays 25 layers, checkpoint has 24Impact
The Solution
Approach
Add centralized validation in
loader.pythat:rotary_emb.inv_freq,cos_sin_cache)lm_head.weightwhen tied toembed_tokens.weight)PPMissingLayer)qkv_projloaded from separateq_proj/k_proj/v_proj)Key Features
load_weights_and_postprocess()SGLANG_STRICT_WEIGHT_LOADING=1environment variableImplementation Details
Files Changed
1.
python/sglang/srt/model_loader/weight_utils.pyAdded
validate_loaded_weights()function:Handles these edge cases:
lm_head.weightwhenconfig.tie_word_embeddings=Truerotary_emb.*,projectorpatternspersistent=FalsePPMissingLayerparameters2.
python/sglang/srt/model_loader/loader.pyModified
load_weights_and_postprocess():3.
test/srt/test_missing_weights_validation.pyNew unit tests:
Usage
Default Behavior (Non-Strict)
If weights are missing, you'll see:
Strict Mode (Recommended for Production)
export SGLANG_STRICT_WEIGHT_LOADING=1 python -m sglang.launch_server --model-path /path/to/modelIf weights are missing, server startup will fail with the above error message.
Testing
Unit Tests
All tests pass, covering:
Integration Testing
The validation automatically runs for all models loaded through SGLang.
To test with a real model:
Historical Context
Git History Analysis
6b0af2853): Original implementation had no validationa3339d8ca): Fixed bug with tied weights - symptom of missing validation11553c1a3): Added warning for extra weights in checkpointWhy This Wasn't Caught Earlier
Migration Guide
For Users
No action needed! Validation defaults to warn-only mode.
Recommended: Enable strict mode for production deployments:
export SGLANG_STRICT_WEIGHT_LOADING=1For Developers
If you maintain custom model classes, ensure your
load_weights()method:The validation will catch any issues during development!
Performance Impact
Future Work
Potential enhancements (out of scope for this PR):
Related Issues
This PR addresses the root cause of several potential issues:
Checklist
References
strict=Trueby default inload_state_dict()load_state_dict()returns(missing_keys, unexpected_keys)Ready for review! 🎯
This is a minimal, focused change that fixes a critical silent bug while maintaining backward compatibility.