Skip to content

Commit 4e08ef6

Browse files
committed
Fix CI: skip meta amax in populate_shared_state
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent c149968 commit 4e08ef6

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

modelopt/torch/quantization/utils/shared_input.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,13 @@ def populate_shared_state(model: nn.Module) -> int:
340340
parallel_state: ParallelState | None = None
341341
for child in members:
342342
wq = getattr(child, wq_attr, None)
343-
if wq is None or getattr(wq, "_amax", None) is None:
343+
amax = getattr(wq, "_amax", None) if wq is not None else None
344+
# Skip uncalibrated or meta (no-data) amax. A meta amax — e.g. from quantizing an
345+
# ``init_empty_weights`` model before dispatch — would make weight_global_amax a
346+
# meta buffer that then breaks the meta->device ``.to()`` (it needs ``to_empty``).
347+
if amax is None or amax.is_meta:
344348
continue
345-
child_maxes.append(reduce_amax(wq._amax, axis=None))
349+
child_maxes.append(reduce_amax(amax, axis=None))
346350
if parallel_state is None:
347351
parallel_state = getattr(child, "parallel_state", None)
348352

tests/unit/torch/quantization/test_shared_input.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,29 @@ def test_populate_writes_max_across_siblings(self):
171171
assert shared is not None
172172
assert torch.isclose(shared, torch.tensor(3.0)), f"expected 3.0, got {shared.item()}"
173173

174+
def test_populate_skips_meta_amax(self):
175+
"""Meta (no-data) ``_amax`` must not become a meta ``weight_global_amax`` buffer.
176+
177+
Quantizing an ``init_empty_weights`` model produces meta ``_amax``; aggregating it
178+
would make ``weight_global_amax`` a meta buffer that breaks the later meta->device
179+
``.to()`` during dispatch. The group is skipped instead, leaving the buffer ``None``.
180+
"""
181+
attn = _DummyAttention()
182+
mtq.replace_quant_module(attn)
183+
cfg = _make_nvfp4_static_cfg()
184+
for proj in (attn.q_proj, attn.k_proj, attn.v_proj):
185+
proj.weight_quantizer.set_from_attribute_config(cfg)
186+
out_features, in_features = proj.weight.shape
187+
proj.weight_quantizer._amax = torch.empty(
188+
(out_features, in_features // NVFP4_BLOCK), device="meta"
189+
)
190+
191+
attach_shared_quant_states(attn, patterns=SIBLING_PATTERNS) # groups q/k/v (no amax needed)
192+
n_groups = populate_shared_state(attn)
193+
194+
assert n_groups == 0 # nothing real to aggregate
195+
assert attn._shared_quant_state.weight_global_amax is None # not a meta tensor
196+
174197
def test_promote_ignores_shared_state_outside_root(self):
175198
"""Promoting a submodule must ignore a back-ref whose owning state is outside it.
176199

0 commit comments

Comments
 (0)