Phase 2: replace dict-style quant_config access with LayerQuantConfig #274
Conversation
Is it possible to request a regression test for this PR or for the current main branch?
Pull request overview
This PR completes the Phase 2 migration away from dict-style quantization config access by introducing a typed, immutable LayerQuantConfig and routing quant config parsing through a parser registry, then updating models/ops/docs/tests to use the new API.
Changes:
- Introduce `atom/quant_spec.py` with `LayerQuantConfig` (frozen dataclass), `ParsedQuantConfig`, and a quant-config parser registry (`QuarkParser`, `GenericParser`).
- Refactor `QuantizationConfig` to parse via the registry and expose typed per-layer resolution via `get_layer_quant_config()`.
- Update model/ops code and docs/tests to use attribute access (e.g., `.quant_dtype`) instead of dict keys.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `atom/quant_spec.py` | Adds typed quant spec + parser registry and built-in parsers. |
| `atom/config.py` | Switches quant parsing to registry + typed resolution; removes dict-based parsing paths. |
| `atom/models/llama.py` | Uses per-layer resolved quant type instead of global dict access. |
| `atom/models/deepseek_v2.py` | Replaces dict indexing with typed attribute access for resolved specs. |
| `atom/models/deepseek_mtp.py` | Replaces dict indexing with typed attribute access. |
| `atom/model_ops/moe.py` | Updates MoE methods to use typed LayerQuantConfig attributes. |
| `atom/model_ops/linear.py` | Updates LinearBase plumbing to use typed spec attributes. |
| `atom/model_ops/layernorm.py` | Updates RMSNorm to resolve per-layer spec via prefix. |
| `atom/model_ops/activation.py` | Updates SiluAndMul to resolve per-layer spec via prefix. |
| `tests/test_quant_config.py` | Updates/refactors tests for typed API and parser registry. |
| `docs/configuration_guide.md` | Updates docs to reflect LayerQuantConfig as a frozen dataclass and new parsing path. |
| `docs/architecture_guide.md` | Removes LayerQuantConfig from atom/config.py ownership list. |
Comments suppressed due to low confidence (1)
docs/configuration_guide.md:139
- The `LayerQuantConfig` table documents `quant_method` as type `str` with default `""`, but the implementation is `quant_method: str | None = None`. Update the docs to reflect that the field may be `None` and clarify what `None` vs an empty string means for downstream logic.
| Field | Type | Default | Description |
|---|---|---|---|
| `quant_type` | `QuantType` | `QuantType.No` | Quantization granularity (see below) |
| `quant_dtype` | `torch.dtype` | `torch.bfloat16` | Data type for quantized weights |
| `is_dynamic` | `bool` | `True` | Use dynamic quantization (scales computed at runtime) |
| `quant_method` | `str` | `""` | Quantization method (e.g., `"quark"`, `"compressed-tensors"`) |
```diff
@@ -99,7 +99,7 @@ def __init__(
         self.act_fn = SiluAndMul(
             fused_quant=self.fused_act_quant, quant_config=quant_config
```
SiluAndMul now resolves its quant spec from quant_config.get_layer_quant_config(prefix), but this call site does not pass prefix. In mixed per-layer quant setups (e.g., MXFP4 MLP), self.quant_type is resolved from the layer prefix and later used to decide whether to unpack (x, scale) from self.act_fn(...); if SiluAndMul was initialized with the wrong (global) spec it may return a tensor while the caller tries to unpack it, causing incorrect behavior or runtime errors. Pass the same layer prefix into SiluAndMul (or derive the condition from self.act_fn.quant_type) so the fused-quant return contract matches the caller logic.
```diff
-            fused_quant=self.fused_act_quant, quant_config=quant_config
+            fused_quant=self.fused_act_quant,
+            quant_config=quant_config,
+            prefix=prefix,
```
```diff
     ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT
 )

-        self.quant_type = quant_config.global_quant_config["quant_type"]
+        self.quant_type = quant_config.get_layer_quant_config(prefix).quant_type
```
self.quant_type is now resolved per-layer via get_layer_quant_config(prefix). Any fused ops in this decoder layer that conditionally return (tensor, scale) based on quant type (e.g., RMSNorm/activation) need to be initialized with the same prefix as well; otherwise the module’s internal resolved quant type can diverge from self.quant_type, and the subsequent tuple-unpacking logic can break. Ensure the prefix is plumbed consistently to the fused submodules that resolve quant specs.
```diff
         quant_method_str = (
             layer_quant_config.quant_method if layer_quant_config else None
         )
         if layer_quant_config.quant_type == QuantType.No:
```
layer_quant_config can be None when quant_config is None, but the code unconditionally accesses layer_quant_config.quant_type / .quant_dtype below. This will raise an AttributeError for unquantized MoE usage. Initialize layer_quant_config to a default LayerQuantConfig() when quant_config is None, or guard all later accesses and treat the None case as QuantType.No.
```diff
-        if layer_quant_config.quant_type == QuantType.No:
+        if layer_quant_config is None or layer_quant_config.quant_type == QuantType.No:
```
```diff
         # 3. Pattern match
         for pattern, spec in self._parsed.layer_pattern_specs:
             if "*" not in pattern:
                 if layer_name in pattern:
```
In get_layer_quant_config, the non-glob pattern branch uses if layer_name in pattern:. This is reversed for the intended semantics (exact match or pattern being a substring/prefix of the layer name) and can both fail to match valid overrides and accidentally match short layer_name values (e.g., "l" matches "lm_head"). Replace this with an explicit equality check (layer_name == pattern) or a clear prefix/substring rule (pattern in layer_name) consistent with the documented resolution order.
```diff
-                if layer_name in pattern:
+                if pattern in layer_name:
```
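To see why the direction of the containment check matters, a standalone illustration (the helper name is hypothetical):

```python
def matches(pattern: str, layer_name: str) -> bool:
    # Intended rule per this review: the pattern is a substring of the
    # full layer name (e.g. a module suffix), not the other way around.
    return pattern in layer_name


# Corrected check: a pattern like "lm_head" matches the full module path.
assert matches("lm_head", "model.lm_head")

# The reversed check (`layer_name in pattern`) gets both cases wrong:
assert "model.lm_head" not in "lm_head"  # misses a valid override
assert "l" in "lm_head"                  # false positive for a layer named "l"
```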
```python
def parse(self, hf_quant_config: dict) -> ParsedQuantConfig:
    quant_method = hf_quant_config.get("quant_method", "")
    config_str = str(hf_quant_config).lower()

    quant_dtype = self._infer_dtype(hf_quant_config, config_str)
    quant_type = self._infer_qtype(hf_quant_config, config_str)
    is_dynamic = hf_quant_config.get("is_dynamic", True)
    exclude = list(hf_quant_config.get("exclude", []) or [])

    global_spec = LayerQuantConfig(
        quant_type=quant_type,
        quant_dtype=quant_dtype,
        is_dynamic=is_dynamic,
        quant_method=quant_method or None,
    )

    return ParsedQuantConfig(global_spec=global_spec, exclude_layers=exclude)
```
GenericParser.parse() only reads excluded layers from the exclude key. The previous implementation (and current docs) reference ignore as the exclude-list key for compressed-tensors/other formats, so this change will silently stop honoring excludes for those configs. Consider supporting both keys (e.g., exclude with fallback to ignore) for backward compatibility.
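A sketch of the backward-compatible lookup this comment asks for. The three keys come from the commit message in this PR; the helper name and the exact priority order are assumptions:

```python
def read_exclude_layers(hf_quant_config: dict) -> list[str]:
    # Different quantizers use different keys for excluded layers:
    #   Quark               -> "exclude"
    #   compressed-tensors  -> "ignore"
    #   gpt-oss / HF        -> "modules_to_not_convert"
    # The first non-empty key wins (the priority order is an assumption).
    for key in ("exclude", "ignore", "modules_to_not_convert"):
        value = hf_quant_config.get(key)
        if value:
            return list(value)
    return []
```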
```diff
 **For quark models** (`quant_method == "quark"`):

-1. Parses `global_quant_config` dict via `parse_quark_config_dict()` to produce the global `LayerQuantConfig`.
+1. Parses `global_quant_config` dict via `QuarkParser` to produce the global `LayerQuantConfig`.
 2. Parses each entry in `layer_quant_config` dict to produce per-layer overrides.
 3. Reads the `"exclude"` list for excluded layers.
 4. Within each config dict, `weight.qscheme` determines `quant_type` (`"per_channel"` → `per_Token`, `"per_tensor"` → `per_Tensor`, `"per_group"` → `per_1x32`), and `weight.dtype` determines `quant_dtype`.
```
This quark parsing section was updated to reference QuarkParser, but the immediately following “For other models” section in this doc still describes the old parse_other_config heuristics and says excludes come from the "ignore" key. Since QuantizationConfig now delegates non-quark parsing to GenericParser (which currently reads exclude, not ignore), the documentation is now internally inconsistent—please update the non-quark section (or adjust GenericParser) to match.
```python
def parse(self, hf_quant_config: dict) -> ParsedQuantConfig:
    quant_method = hf_quant_config.get("quant_method", "")
    config_str = str(hf_quant_config).lower()

    quant_dtype = self._infer_dtype(hf_quant_config, config_str)
    quant_type = self._infer_qtype(hf_quant_config, config_str)
    is_dynamic = hf_quant_config.get("is_dynamic", True)
    exclude = list(hf_quant_config.get("exclude", []) or [])
```
GenericParser.parse() introduces new heuristics for non-quark quant configs, but current tests only cover registry lookup and Quark parsing. Add unit tests that exercise GenericParser parsing (dtype/qtype inference and exclude-list handling) to prevent regressions for compressed-tensors/GPTQ/AWQ-style configs.
|
@thpereir |
… for per-layer quant config

Introduce atom/quant_spec.py with:
- LayerQuantConfig: frozen dataclass with typed attribute access (quant_type, quant_dtype, is_dynamic, quant_method) replacing the old dict-based LayerQuantConfig(dict) subclass
- ParsedQuantConfig: structured output of HF config parsing
- Parser registry (@register_quant_parser) with QuarkParser and GenericParser (fallback for compressed-tensors, GPTQ, AWQ, etc.)

Refactor QuantizationConfig (atom/config.py):
- Internal storage now uses ParsedQuantConfig via parser registry
- get_layer_quant_config(prefix) -> LayerQuantConfig (frozen dataclass)
- global_quant_config property -> LayerQuantConfig
- Convenience properties: quant_type, quant_dtype, is_dynamic
- compute_hash() uses typed internal structures

Migrate all consumers to typed attribute access:
- linear.py: layer_quant_config.quant_type instead of ["quant_type"]
- moe.py: all MoE method classes use LayerQuantConfig type hints
- activation.py, layernorm.py: accept prefix param, use get_layer_quant_config() instead of bypassing with global_quant_config
- deepseek_mtp.py, deepseek_v2.py, llama.py: use get_layer_quant_config()

Fix GenericParser exclude-layer key handling (atom/quant_spec.py):
- Different quantizers use different keys for excluded layers: compressed-tensors uses "ignore", gpt-oss/HF uses "modules_to_not_convert", Quark uses "exclude"
- GenericParser now tries all three keys in priority order so excluded layers are never silently treated as quantized

Fix hard-coded quant_config=None across models:
- gpt_oss.py OAIAttention: qkv_proj and o_proj were passing quant_config=None, preventing fp8/mxfp4 quantization on attention projections in Quark gpt-oss models (e.g. fp8 qkv + mxfp4 MoE); both now receive quant_config
- deepseek_v2.py Indexer: weights_proj passed quant_config=None while sibling linears wq_b and wk correctly used quant_config; fixed for consistency
- qwen3_next.py GatedDeltaNet: conv1d ColumnParallelLinear omitted quant_config while other linears in the same class passed it; fixed
…cross all models and ops
Motivation
Following up the DeepSeek R1 enablement with per-layer quantization support, this PR converts all other models to use the same structured quantization config instead of a dict.
This PR depends on both #236 and #268
Technical Details
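The per-layer resolution priority described in this PR (exclude list, then exact per-layer override, then pattern match, then global fallback) can be sketched as a standalone function. `resolve_spec` and its argument shapes are hypothetical stand-ins for the real `QuantizationConfig.resolve`:

```python
from fnmatch import fnmatch


def resolve_spec(prefix, exclude, overrides, patterns, global_spec, unquantized):
    # 1. Excluded layers are never quantized.
    if any(pat == prefix or fnmatch(prefix, pat) for pat in exclude):
        return unquantized
    # 2. Exact per-layer override.
    if prefix in overrides:
        return overrides[prefix]
    # 3. Pattern match: glob if "*" is present, else substring of the prefix.
    for pattern, spec in patterns:
        if fnmatch(prefix, pattern) if "*" in pattern else pattern in prefix:
            return spec
    # 4. Global fallback.
    return global_spec
```

Collapsing the four cases into one call keeps every consumer on the same documented priority instead of each site re-implementing dict lookups.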
`resolve(prefix)` handles the exclude list, per-layer overrides, pattern matching, and global fallback in one call with documented priority. This is less error prone than the previous approach of indexing into the dict. What was `quant_config.get_layer_quant_config(f"{prefix}.fused_qkv_a_proj")` becomes `quant_config.resolve(prefix)`.

`QuantizationConfig.__init__` directly.

Test Plan
Test Result
Submission Checklist