Skip to content

Commit 52b9334

Browse files
committed
_replace_linears_for_quant_paths runs, projection paths point at StreamingQLinear modules.
1 parent 2b3d943 commit 52b9334

2 files changed

Lines changed: 75 additions & 0 deletions

File tree

src/kvboost/streaming/awq_loader.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def materialize_into_module(
374374
hf_model: "torch.nn.Module",
375375
*,
376376
only_resident: bool = True,
377+
skip_quant_projections: bool = True,
377378
) -> None:
378379
"""Write resident tensors into the matching submodules of ``hf_model``.
379380
@@ -385,12 +386,31 @@ def materialize_into_module(
385386
This is the bridge that lets us use ``accelerate.init_empty_weights``
386387
for the skeleton and then selectively materialize only the layers that
387388
should live in VRAM.
389+
390+
``skip_quant_projections`` (default True) skips ``*.qweight``,
391+
``*.scales``, ``*.qzeros`` tensors. When the streaming pipeline
392+
replaces projection modules with :class:`StreamingQLinear` (which
393+
has no ``qweight`` parameter slot), naively assigning those tensors
394+
via ``setattr`` creates **orphaned duplicate** allocations — they
395+
sit on the new module as bare attributes while the real binding
396+
happens later via :meth:`bind_streaming_qlinears` into the
397+
``_qweight`` / ``_scales`` / ``_qzeros`` slots. Skipping them here
398+
avoids that double-allocation (~2 GiB for a 32B model with 8
399+
resident layers).
388400
"""
389401
assert self.index is not None
390402

403+
def _is_quant_proj(name: str) -> bool:
    """Return True when ``name`` names a packed AWQ projection tensor.

    A tensor counts as a quantized projection when its name ends in
    ``.qweight``, ``.scales`` or ``.qzeros``.
    """
    # ``str.endswith`` accepts a tuple of suffixes, so a single call
    # covers all three cases instead of a Python-level ``or`` chain.
    return name.endswith((".qweight", ".scales", ".qzeros"))
409+
391410
wanted = [
392411
spec for spec in self.index.tensors.values()
393412
if (spec.is_resident if only_resident else True)
413+
and not (skip_quant_projections and _is_quant_proj(spec.name))
394414
]
395415

396416
by_shard: dict[Path, list[TensorSpec]] = {}

tests/streaming/test_mps_path.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,61 @@ def test_bound_streaming_qlinear_runs_forward(tmp_path):
224224
assert out.shape == (2, 16)
225225

226226

227+
def test_materialize_skips_quant_projections_by_default(tmp_path):
    """Regression guard: default materialization leaves AWQ tensors alone.

    The packed tensors (``*.qweight``/``*.scales``/``*.qzeros``) are bound
    separately by ``bind_streaming_qlinears`` into the StreamingQLinear's
    ``_qweight`` etc. slots. If ``materialize_into_module`` also assigned
    them via ``setattr``, a second on-device copy that nothing reads would
    be created (~2 GiB of wasted VRAM on a 32B model).
    """
    loader = _make_loader(tmp_path, num_layers=2)
    skeleton = _FakeHfModel(num_layers=2)

    # The streaming path swaps projections for StreamingQLinear modules
    # before materializing, so do the same here.
    swapped = _replace_linears_for_quant_paths(
        skeleton, loader=loader, group_size=8, prefer="torch",
    )

    # Rely on the default ``skip_quant_projections=True``.
    loader.materialize_into_module(skeleton, only_resident=False)

    bare_names = ("qweight", "scales", "qzeros")
    for layer_idx, layer_repl in swapped.items():
        for sub_path, qlin in layer_repl.items():
            # The bound slot must remain unset until the bind step runs.
            assert qlin._qweight is None, (
                f"layer {layer_idx} {sub_path}: _qweight set without bind"
            )
            # No bare attribute may have been attached either; a missing
            # attribute and an attribute equal to None are both acceptable.
            for attr in bare_names:
                assert getattr(qlin, attr, None) is None, (
                    f"layer {layer_idx} {sub_path}: bare .{attr} attr leaked"
                )
261+
262+
263+
def test_materialize_with_skip_disabled_does_assign_projections(tmp_path):
    """Opting out of the skip writes projection tensors via ``setattr``.

    The opt-out is for callers that genuinely want the projections
    assigned (e.g. when the target is autoawq's WQLinear_GEMM with real
    qweight buffers).
    """
    loader = _make_loader(tmp_path, num_layers=1)
    # Deliberately no StreamingQLinear replacement: q_proj stays a plain
    # nn.Linear, which declares no qweight slot of its own.
    model = _FakeHfModel(num_layers=1)

    loader.materialize_into_module(
        model,
        only_resident=False,
        skip_quant_projections=False,
    )

    # With the skip disabled, qweight should now be present on the
    # projection as a bare attribute.
    q_proj = model.model.layers[0].self_attn.q_proj
    assert hasattr(q_proj, "qweight")
280+
281+
227282
def test_iter_decoder_layers_on_fake_skeleton():
228283
model = _FakeHfModel(num_layers=4)
229284
pairs = _iter_decoder_layers(model)

0 commit comments

Comments
 (0)