Skip to content

Commit 5cfb71c

Browse files
committed
speculative decoding bugs
1 parent f34904b commit 5cfb71c

1 file changed

Lines changed: 91 additions & 50 deletions

File tree

src/kvboost/streaming/kernels/marlin.py

Lines changed: 91 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -49,75 +49,113 @@ def _try_resolve(candidates: tuple[tuple[str, str], ...]) -> Optional[Callable[.
4949
return None
5050

5151

52-
_GEMM_FN: Optional[Callable[..., Any]] = _try_resolve(_GEMM_CANDIDATES)
52+
def _resolve_gemm(
53+
candidates: tuple[tuple[str, str], ...]
54+
) -> tuple[Optional[Callable[..., Any]], bool]:
55+
"""Resolve the GEMM fn AND whether it consumes the Marlin-repacked layout.
56+
57+
Only vLLM's ``awq_marlin_gemm`` reads the repacked layout; the autoawq
58+
kernels (``gemm_forward_cuda`` / ``awq_gemm``) want the ORIGINAL AWQ packing,
59+
so repacking the weights under them yields garbage. We track this so the
60+
loader's repack step can no-op for the raw-layout kernels.
61+
"""
62+
for module_name, attr in candidates:
63+
try:
64+
mod = __import__(module_name, fromlist=[attr])
65+
except Exception:
66+
continue
67+
fn = getattr(mod, attr, None)
68+
if fn is not None:
69+
needs_repack = attr == "awq_marlin_gemm"
70+
logger.debug(
71+
"resolved %s.%s for AWQ GEMM (marlin_layout=%s)",
72+
module_name, attr, needs_repack,
73+
)
74+
return fn, needs_repack
75+
return None, False
76+
77+
78+
_GEMM_FN, _GEMM_NEEDS_REPACK = _resolve_gemm(_GEMM_CANDIDATES)
5379
_REPACK_FN: Optional[Callable[..., Any]] = _try_resolve(_REPACK_CANDIDATES)
5480

5581

5682
# ── Probe the kernel's call signature once at load time ──────────────────────
57-
# All known AWQ GEMM kernels use (x, qweight, qzeros, scales, last_arg) where
58-
# last_arg is either split_k_iters (autoawq style) or group_size (vLLM Marlin
59-
# style). We probe with tiny tensors here and cache the working call so the
60-
# hot forward path never pays try/except overhead.
83+
# AWQ int4 GEMM kernels DISAGREE on the call layout, so we can't assume one:
84+
# * autoawq awq_ext.gemm_forward_cuda : (x, qw, SCALES, ZEROS, split_k_iters)
85+
# * vLLM awq_gemm : (x, qw, ZEROS, SCALES, split_k_iters)
86+
# * vLLM awq_marlin_gemm : (x, qw, ZEROS, SCALES, group_size)
87+
# They differ in BOTH the scales/zeros order AND the trailing int. We probe each
88+
# combination with tiny tensors and cache the first that runs. Because scales is
89+
# fp16 and zeros is int32, the WRONG scales/zeros order hits a kernel dtype check
90+
# and raises — so try/except reliably discriminates the order.
91+
# (An earlier version hard-coded only the vLLM (zeros, scales) order, which
92+
# silently disabled autoawq's awq_ext on every box that had it -> the slow torch
93+
# dequant fallback. That was the bug.)
94+
95+
_SPLIT_K_ITERS = 8 # autoawq's default; tunes K-dim reduction parallelism
96+
97+
# (label, scales_first, last_kind), most-preferred first. ``scales_first``
98+
# selects the autoawq (True) vs vLLM (False) order of the scales/zeros pair.
99+
_GEMM_SIG_CANDIDATES = (
100+
("autoawq (x,qw,scales,zeros,split_k)", True, "split_k"),
101+
("vllm (x,qw,zeros,scales,split_k)", False, "split_k"),
102+
("vllm (x,qw,zeros,scales,group_size)", False, "group_size"),
103+
("autoawq (x,qw,scales,zeros,group_size)", True, "group_size"),
104+
)
61105

62-
_SPLIT_K_ITERS = 8 # autoawq's default; tunes K-dim parallelism
106+
107+
def _make_gemm_caller(scales_first: bool, last_kind: str) -> Callable[..., Any]:
108+
"""Wrap _GEMM_FN so callers invoke it canonically as
109+
``caller(x_2d, qweight, qzeros, scales, group_size)`` regardless of the
110+
kernel's native scales/zeros order or trailing-int convention."""
111+
def _call(x_2d, qw, qz, sc, group_size): # noqa: ANN001
112+
last = _SPLIT_K_ITERS if last_kind == "split_k" else group_size
113+
if scales_first:
114+
return _GEMM_FN(x_2d, qw, sc, qz, last)
115+
return _GEMM_FN(x_2d, qw, qz, sc, last)
116+
return _call
63117

64118

65119
def _probe_gemm_signature() -> Optional[Callable[..., Any]]:
66-
"""Return a zero-argument callable that calls the resolved GEMM fn with the
67-
correct arg order, or None if no kernel is available or the probe fails.
68-
69-
All known kernels share the layout:
70-
fn(x_2d, qweight, qzeros, scales, last_arg)
71-
where last_arg is split_k_iters (int) or group_size (int).
72-
The old code had scales/qzeros SWAPPED on the first try, causing a
73-
RuntimeError on every forward that silently fell through to the slow
74-
torch dequant path.
120+
"""Return a caller invoking the resolved GEMM fn with the correct arg order,
121+
or None if no kernel is available or every known signature fails the probe.
75122
"""
76123
if _GEMM_FN is None:
77124
return None
78125

79126
try:
80127
import torch as _torch
81-
# Minimal tensors: in=128 (one group), out=16 (× pack=8 → 128 packed).
82-
group_size_probe = 128
83-
in_f, out_f = group_size_probe, 16
84128
device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu")
85129
if device.type != "cuda":
86130
return None
87131

88-
x_p = _torch.zeros(1, in_f, dtype=_torch.float16, device=device)
89-
qw_p = _torch.zeros(in_f, out_f, dtype=_torch.int32, device=device)
90-
scales_p = _torch.ones(1, out_f * 8, dtype=_torch.float16, device=device)
91-
qzeros_p = _torch.zeros(1, out_f, dtype=_torch.int32, device=device)
92-
93-
# Try split_k_iters style (autoawq / awq_ext).
94-
try:
95-
_GEMM_FN(x_p, qw_p, qzeros_p, scales_p, _SPLIT_K_ITERS)
96-
logger.debug("marlin/awq GEMM: using split_k_iters signature")
97-
98-
def _call(x_2d, qw, qz, sc, *_): # noqa: ANN001
99-
return _GEMM_FN(x_2d, qw, qz, sc, _SPLIT_K_ITERS)
100-
101-
return _call
102-
except (RuntimeError, TypeError):
103-
pass
104-
105-
# Try group_size style (vLLM awq_marlin_gemm).
106-
try:
107-
_GEMM_FN(x_p, qw_p, qzeros_p, scales_p, group_size_probe)
108-
logger.debug("marlin/awq GEMM: using group_size signature")
109-
110-
def _call(x_2d, qw, qz, sc, group_size): # noqa: ANN001
111-
return _GEMM_FN(x_2d, qw, qz, sc, group_size)
112-
113-
return _call
114-
except (RuntimeError, TypeError):
115-
pass
132+
# Minimal valid AWQ shapes: K=256 (2 groups of 128), N=256 (=32 × pack 8).
133+
group_size_probe = 128
134+
in_f, out_f = 256, 32
135+
n_groups = in_f // group_size_probe
136+
x_p = _torch.zeros(1, in_f, dtype=_torch.float16, device=device)
137+
qw_p = _torch.zeros(in_f, out_f, dtype=_torch.int32, device=device)
138+
scales_p = _torch.ones(n_groups, out_f * 8, dtype=_torch.float16, device=device)
139+
qzeros_p = _torch.zeros(n_groups, out_f, dtype=_torch.int32, device=device)
140+
141+
for label, scales_first, last_kind in _GEMM_SIG_CANDIDATES:
142+
caller = _make_gemm_caller(scales_first, last_kind)
143+
try:
144+
out = caller(x_p, qw_p, qzeros_p, scales_p, group_size_probe)
145+
except (RuntimeError, TypeError):
146+
continue
147+
# Guard against a silently-accepted wrong layout: the output must be
148+
# (M, out_features) and finite.
149+
if out.shape[0] != 1 or out.shape[-1] != out_f * 8 \
150+
or not _torch.isfinite(out).all():
151+
continue
152+
logger.info("marlin/awq GEMM: using %s signature", label)
153+
return caller
116154

117155
logger.warning(
118-
"marlin/awq GEMM fn %r: neither split_k_iters nor group_size "
119-
"signature worked during probe — disabling kernel. AWQ will use "
120-
"ExLlamaV2 or the torch dequant fallback.",
156+
"marlin/awq GEMM fn %r: no known signature worked during probe "
157+
"(tried autoawq + vLLM orders x split_k/group_size) — disabling "
158+
"kernel. AWQ will use ExLlamaV2 or the torch dequant fallback.",
121159
_GEMM_FN,
122160
)
123161
return None
@@ -181,7 +219,10 @@ def awq_marlin_repack(
181219
"""Repack an AWQ ``qweight`` into Marlin's layout, if a repack kernel
182220
is available. Falls back to returning the input contiguous if not.
183221
"""
184-
if _REPACK_FN is None:
222+
# Repack ONLY when the resolved GEMM actually consumes the Marlin layout
223+
# (vLLM awq_marlin_gemm). The autoawq raw-layout kernels must keep the
224+
# ORIGINAL AWQ packing, or the GEMM reads garbage.
225+
if _REPACK_FN is None or not _GEMM_NEEDS_REPACK:
185226
return qweight.contiguous()
186227
try:
187228
return _REPACK_FN(qweight, in_features, out_features, num_bits)

0 commit comments

Comments
 (0)