[Misc] Remove patch files related to rotary quantization, and register new mtp/eagle3 model classes#10896
[Misc] Remove patch files related to rotary quantization, and register new mtp/eagle3 model classes#10896wangbj127 wants to merge 1 commit into
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the rotary quantization support for DeepSeek MTP models by moving away from monkey-patching via a dedicated patch file. Instead, it implements cleaner, class-based overrides within the 'vllm_ascend' model directory. This change improves code modularity and reduces reliance on global state manipulation, aligning with better software engineering practices for model extensions. Highlights
New Features🧠 You can now enable Memory (public preview) to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
There was a problem hiding this comment.
Code Review
This pull request refactors the codebase by removing the deprecated monkey-patching file patch_deepseek_mtp.py and instead registering the AscendDeepSeekMTP and AscendGlmMoeDsaForCausalLM models directly in the model registry. Feedback on this PR includes a suggestion to format the PR title and summary according to the repository style guide, as well as several robust weight-loading improvements: filtering out rot. weights when rotary weights are not used to prevent potential AttributeError or ValueError crashes, expanding the skip_prefixes in AscendGlmMoeDsaForCausalLM to handle fully qualified weight names, and ensuring is_rot_weight_used consistently returns a boolean value.
| ModelRegistry.register_model("DeepseekV4ForCausalLM", "vllm_ascend.models.deepseek_v4:AscendDeepseekV4ForCausalLM") | ||
|
|
||
| ModelRegistry.register_model("DeepSeekV4MTPModel", "vllm_ascend.models.deepseek_v4_mtp:DeepSeekV4MTP") | ||
| ModelRegistry.register_model("DeepSeekMTPModel", "vllm_ascend.models.deepseek_mtp:AscendDeepSeekMTP") |
There was a problem hiding this comment.
According to the Repository Style Guide, here is the suggested PR Title and PR Summary for this pull request:
Suggested PR Title:
[Ops][Misc] Delete patch files related to rotary quantization and register AscendDeepSeekMTP/AscendGlmMoeDsaForCausalLM modelsSuggested PR Summary:
### What this PR does / why we need it?
This PR deletes the deprecated patch files related to rotary quantization (`patch_deepseek_mtp.py`) and registers the new `AscendDeepSeekMTP` and `AscendGlmMoeDsaForCausalLM` models directly in the model registry. This refactoring improves maintainability by replacing runtime monkey-patching with clean subclassing and proper weight loading/mapping.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with existing tests.References
- Format the PR Title and PR Summary according to the Repository Style Guide. (link)
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | ||
| weights_mapper = WeightsMapper( | ||
| orig_to_new_prefix={ | ||
| "rot.": f"model.layers.{self.config.num_hidden_layers}.rot." | ||
| }, | ||
| ) | ||
| return super().load_weights(weights_mapper.apply(weights)) |
There was a problem hiding this comment.
If self.is_rot_weight_used is False, self.rot is not defined. However, if the checkpoint contains rot.weight, super().load_weights will attempt to load it into self.rot, resulting in an AttributeError or ValueError. We should filter out any rot. weights when self.is_rot_weight_used is False to prevent potential crashes.
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| weights_mapper = WeightsMapper( | |
| orig_to_new_prefix={ | |
| "rot.": f"model.layers.{self.config.num_hidden_layers}.rot." | |
| }, | |
| ) | |
| return super().load_weights(weights_mapper.apply(weights)) | |
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| if not self.is_rot_weight_used: | |
| weights = ((name, tensor) for name, tensor in weights if not name.startswith("rot.")) | |
| weights_mapper = WeightsMapper( | |
| orig_to_new_prefix={ | |
| "rot.": f"model.layers.{self.config.num_hidden_layers}.rot." | |
| }, | |
| ) | |
| return super().load_weights(weights_mapper.apply(weights)) |
| class AscendGlmMoeDsaForCausalLM(GlmMoeDsaForCausalLM): | ||
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | ||
| loader = AutoWeightsLoader(self, skip_prefixes=["rot."]) | ||
| return loader.load_weights(weights) |
There was a problem hiding this comment.
If the checkpoint contains model.layers.{num_hidden_layers}.rot.weight, skip_prefixes=["rot."] will not match it because it starts with "model.layers.". This will cause AutoWeightsLoader to fail with a ValueError due to unexpected keys. We should add f"model.layers.{self.config.num_hidden_layers}.rot." to skip_prefixes to ensure all variations of the rot weight are skipped.
| class AscendGlmMoeDsaForCausalLM(GlmMoeDsaForCausalLM): | |
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loader = AutoWeightsLoader(self, skip_prefixes=["rot."]) | |
| return loader.load_weights(weights) | |
| class AscendGlmMoeDsaForCausalLM(GlmMoeDsaForCausalLM): | |
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| loader = AutoWeightsLoader( | |
| self, | |
| skip_prefixes=[ | |
| "rot.", | |
| f"model.layers.{self.config.num_hidden_layers}.rot.", | |
| ], | |
| ) | |
| return loader.load_weights(weights) |
| def is_rot_weight_used(vllm_config: VllmConfig = None): | ||
| global _IS_ROT_WEIGHT_USED | ||
| if vllm_config is None: | ||
| from vllm.config import get_current_vllm_config_or_none | ||
|
|
||
| vllm_config = get_current_vllm_config_or_none() | ||
| if _IS_ROT_WEIGHT_USED is None and vllm_config is not None: | ||
| quant_description = getattr(vllm_config.quant_config, "quant_description", None) | ||
| _IS_ROT_WEIGHT_USED = quant_description.get("is_rot_used", False) if quant_description is not None else False | ||
| return _IS_ROT_WEIGHT_USED |
There was a problem hiding this comment.
If vllm_config is None (e.g., when called without config and get_current_vllm_config_or_none() also returns None), the function returns None instead of a boolean, which can cause falsy/truthy confusion or type errors. We should return False as a safe fallback.
| def is_rot_weight_used(vllm_config: VllmConfig = None): | |
| global _IS_ROT_WEIGHT_USED | |
| if vllm_config is None: | |
| from vllm.config import get_current_vllm_config_or_none | |
| vllm_config = get_current_vllm_config_or_none() | |
| if _IS_ROT_WEIGHT_USED is None and vllm_config is not None: | |
| quant_description = getattr(vllm_config.quant_config, "quant_description", None) | |
| _IS_ROT_WEIGHT_USED = quant_description.get("is_rot_used", False) if quant_description is not None else False | |
| return _IS_ROT_WEIGHT_USED | |
| def is_rot_weight_used(vllm_config: VllmConfig = None) -> bool: | |
| global _IS_ROT_WEIGHT_USED | |
| if vllm_config is None: | |
| from vllm.config import get_current_vllm_config_or_none | |
| vllm_config = get_current_vllm_config_or_none() | |
| if _IS_ROT_WEIGHT_USED is None: | |
| if vllm_config is not None: | |
| quant_description = getattr(vllm_config.quant_config, "quant_description", None) | |
| _IS_ROT_WEIGHT_USED = quant_description.get("is_rot_used", False) if quant_description is not None else False | |
| else: | |
| return False | |
| return _IS_ROT_WEIGHT_USED |
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
71bc39b to
f458198
Compare
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
Signed-off-by: Wangbj127 <wangbj1207@126.com>
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
What this PR does / why we need it?
patch_deepseek_mtp.pyandpatch_draft_quarot.pyare deprecated. Instead, two new model classes (AscendDeepSeekMTPandAscendEagle3LlamaForCausalLM) are registered.Does this PR introduce any user-facing change?
No.
How was this patch tested?