@@ -49,75 +49,113 @@ def _try_resolve(candidates: tuple[tuple[str, str], ...]) -> Optional[Callable[.
4949 return None
5050
5151
52- _GEMM_FN : Optional [Callable [..., Any ]] = _try_resolve (_GEMM_CANDIDATES )
def _resolve_gemm(
    candidates: tuple[tuple[str, str], ...]
) -> tuple[Optional[Callable[..., Any]], bool]:
    """Resolve the GEMM fn AND whether it consumes the Marlin-repacked layout.

    Only vLLM's ``awq_marlin_gemm`` reads the repacked layout; the autoawq
    kernels (``gemm_forward_cuda`` / ``awq_gemm``) want the ORIGINAL AWQ packing,
    so repacking the weights under them yields garbage. We track this so the
    loader's repack step can no-op for the raw-layout kernels.

    Args:
        candidates: ``(module_name, attribute)`` pairs, tried in order.

    Returns:
        ``(fn, needs_repack)`` for the first candidate whose module imports
        and exposes a *callable* ``attribute``. ``needs_repack`` is True only
        when that attribute is named ``awq_marlin_gemm``. Returns
        ``(None, False)`` when no candidate resolves.
    """
    for module_name, attr in candidates:
        try:
            # ``fromlist`` makes __import__ return the leaf module rather
            # than the top-level package.
            mod = __import__(module_name, fromlist=[attr])
        except Exception as exc:  # noqa: BLE001
            # Kept broad on purpose: any import-time failure just means this
            # candidate is unavailable. Leave a trace instead of failing
            # silently, so a missing kernel can be diagnosed from the logs.
            logger.debug(
                "AWQ GEMM candidate %s.%s not importable: %r",
                module_name, attr, exc,
            )
            continue

        fn = getattr(mod, attr, None)
        # A missing attribute (None) or a non-callable placeholder must not
        # shadow a later candidate that actually works.
        if not callable(fn):
            continue

        needs_repack = attr == "awq_marlin_gemm"
        logger.debug(
            "resolved %s.%s for AWQ GEMM (marlin_layout=%s)",
            module_name, attr, needs_repack,
        )
        return fn, needs_repack
    return None, False
76+
77+
# Resolved once at import time. ``_GEMM_FN`` is the first usable GEMM kernel
# (or None) and ``_GEMM_NEEDS_REPACK`` records whether that kernel reads the
# Marlin-repacked weight layout.
_GEMM_FN, _GEMM_NEEDS_REPACK = _resolve_gemm(_GEMM_CANDIDATES)
# Optional repack kernel; None when no candidate in _REPACK_CANDIDATES resolves.
_REPACK_FN: Optional[Callable[..., Any]] = _try_resolve(_REPACK_CANDIDATES)
5480
5581
5682# ── Probe the kernel's call signature once at load time ──────────────────────
57- # All known AWQ GEMM kernels use (x, qweight, qzeros, scales, last_arg) where
58- # last_arg is either split_k_iters (autoawq style) or group_size (vLLM Marlin
59- # style). We probe with tiny tensors here and cache the working call so the
60- # hot forward path never pays try/except overhead.
83+ # AWQ int4 GEMM kernels DISAGREE on the call layout, so we can't assume one:
84+ # * autoawq awq_ext.gemm_forward_cuda : (x, qw, SCALES, ZEROS, split_k_iters)
85+ # * vLLM awq_gemm : (x, qw, ZEROS, SCALES, split_k_iters)
86+ # * vLLM awq_marlin_gemm : (x, qw, ZEROS, SCALES, group_size)
87+ # They differ in BOTH the scales/zeros order AND the trailing int. We probe each
88+ # combination with tiny tensors and cache the first that runs. Because scales is
89+ # fp16 and zeros is int32, the WRONG scales/zeros order hits a kernel dtype check
90+ # and raises — so try/except reliably discriminates the order.
91+ # (An earlier version hard-coded only the vLLM (zeros, scales) order, which
92+ # silently disabled autoawq's awq_ext on every box that had it -> the slow torch
93+ # dequant fallback. That was the bug.)
94+
_SPLIT_K_ITERS = 8  # autoawq's default; tunes K-dim reduction parallelism

# (label, scales_first, last_kind), most-preferred first. ``scales_first``
# selects the autoawq (True) vs vLLM (False) order of the scales/zeros pair.
# ``last_kind`` names the trailing integer argument: "split_k" passes
# _SPLIT_K_ITERS, "group_size" passes the caller's group size (see
# _make_gemm_caller). ``label`` is only used in the log message emitted when
# a signature is selected. Order matters: the probe walks this tuple from the
# top and keeps the first entry whose probe call succeeds.
_GEMM_SIG_CANDIDATES = (
    ("autoawq (x,qw,scales,zeros,split_k)", True, "split_k"),
    ("vllm (x,qw,zeros,scales,split_k)", False, "split_k"),
    ("vllm (x,qw,zeros,scales,group_size)", False, "group_size"),
    ("autoawq (x,qw,scales,zeros,group_size)", True, "group_size"),
)
61105
62- _SPLIT_K_ITERS = 8 # autoawq's default; tunes K-dim parallelism
106+
def _make_gemm_caller(scales_first: bool, last_kind: str) -> Callable[..., Any]:
    """Wrap _GEMM_FN so callers invoke it canonically as
    ``caller(x_2d, qweight, qzeros, scales, group_size)`` regardless of the
    kernel's native scales/zeros order or trailing-int convention.

    Args:
        scales_first: If True the kernel receives ``scales`` before
            ``qzeros`` (autoawq order); otherwise ``qzeros`` comes first
            (vLLM order).
        last_kind: ``"split_k"`` to pass ``_SPLIT_K_ITERS`` as the final
            argument, or ``"group_size"`` to forward the caller's
            ``group_size``.

    Returns:
        A five-argument callable that returns whatever ``_GEMM_FN`` returns.

    Raises:
        ValueError: If ``last_kind`` is not one of the two known conventions.
            Previously any other string silently behaved like
            ``"group_size"``, which would hide a typo in the candidate table.
    """
    if last_kind not in ("split_k", "group_size"):
        raise ValueError(
            f"unknown last_kind {last_kind!r}; expected 'split_k' or 'group_size'"
        )
    use_split_k = last_kind == "split_k"

    # The scales/zeros order is fixed per caller, so pick the closure once
    # here instead of branching on it in every forward call. _GEMM_FN and
    # _SPLIT_K_ITERS are still read at call time, exactly as before.
    if scales_first:
        def _call(x_2d, qw, qz, sc, group_size):  # noqa: ANN001
            last = _SPLIT_K_ITERS if use_split_k else group_size
            return _GEMM_FN(x_2d, qw, sc, qz, last)
    else:
        def _call(x_2d, qw, qz, sc, group_size):  # noqa: ANN001
            last = _SPLIT_K_ITERS if use_split_k else group_size
            return _GEMM_FN(x_2d, qw, qz, sc, last)
    return _call
63117
64118
65119def _probe_gemm_signature () -> Optional [Callable [..., Any ]]:
66- """Return a zero-argument callable that calls the resolved GEMM fn with the
67- correct arg order, or None if no kernel is available or the probe fails.
68-
69- All known kernels share the layout:
70- fn(x_2d, qweight, qzeros, scales, last_arg)
71- where last_arg is split_k_iters (int) or group_size (int).
72- The old code had scales/qzeros SWAPPED on the first try, causing a
73- RuntimeError on every forward that silently fell through to the slow
74- torch dequant path.
120+ """Return a caller invoking the resolved GEMM fn with the correct arg order,
121+ or None if no kernel is available or every known signature fails the probe.
75122 """
76123 if _GEMM_FN is None :
77124 return None
78125
79126 try :
80127 import torch as _torch
81- # Minimal tensors: in=128 (one group), out=16 (× pack=8 → 128 packed).
82- group_size_probe = 128
83- in_f , out_f = group_size_probe , 16
84128 device = _torch .device ("cuda" if _torch .cuda .is_available () else "cpu" )
85129 if device .type != "cuda" :
86130 return None
87131
88- x_p = _torch .zeros (1 , in_f , dtype = _torch .float16 , device = device )
89- qw_p = _torch .zeros (in_f , out_f , dtype = _torch .int32 , device = device )
90- scales_p = _torch .ones (1 , out_f * 8 , dtype = _torch .float16 , device = device )
91- qzeros_p = _torch .zeros (1 , out_f , dtype = _torch .int32 , device = device )
92-
93- # Try split_k_iters style (autoawq / awq_ext).
94- try :
95- _GEMM_FN (x_p , qw_p , qzeros_p , scales_p , _SPLIT_K_ITERS )
96- logger .debug ("marlin/awq GEMM: using split_k_iters signature" )
97-
98- def _call (x_2d , qw , qz , sc , * _ ): # noqa: ANN001
99- return _GEMM_FN (x_2d , qw , qz , sc , _SPLIT_K_ITERS )
100-
101- return _call
102- except (RuntimeError , TypeError ):
103- pass
104-
105- # Try group_size style (vLLM awq_marlin_gemm).
106- try :
107- _GEMM_FN (x_p , qw_p , qzeros_p , scales_p , group_size_probe )
108- logger .debug ("marlin/awq GEMM: using group_size signature" )
109-
110- def _call (x_2d , qw , qz , sc , group_size ): # noqa: ANN001
111- return _GEMM_FN (x_2d , qw , qz , sc , group_size )
112-
113- return _call
114- except (RuntimeError , TypeError ):
115- pass
132+ # Minimal valid AWQ shapes: K=256 (2 groups of 128), N=256 (=32 × pack 8).
133+ group_size_probe = 128
134+ in_f , out_f = 256 , 32
135+ n_groups = in_f // group_size_probe
136+ x_p = _torch .zeros (1 , in_f , dtype = _torch .float16 , device = device )
137+ qw_p = _torch .zeros (in_f , out_f , dtype = _torch .int32 , device = device )
138+ scales_p = _torch .ones (n_groups , out_f * 8 , dtype = _torch .float16 , device = device )
139+ qzeros_p = _torch .zeros (n_groups , out_f , dtype = _torch .int32 , device = device )
140+
141+ for label , scales_first , last_kind in _GEMM_SIG_CANDIDATES :
142+ caller = _make_gemm_caller (scales_first , last_kind )
143+ try :
144+ out = caller (x_p , qw_p , qzeros_p , scales_p , group_size_probe )
145+ except (RuntimeError , TypeError ):
146+ continue
147+ # Guard against a silently-accepted wrong layout: the output must be
148+ # (M, out_features) and finite.
149+ if out .shape [0 ] != 1 or out .shape [- 1 ] != out_f * 8 \
150+ or not _torch .isfinite (out ).all ():
151+ continue
152+ logger .info ("marlin/awq GEMM: using %s signature" , label )
153+ return caller
116154
117155 logger .warning (
118- "marlin/awq GEMM fn %r: neither split_k_iters nor group_size "
119- "signature worked during probe — disabling kernel. AWQ will use "
120- "ExLlamaV2 or the torch dequant fallback." ,
156+ "marlin/awq GEMM fn %r: no known signature worked during probe "
157+ "(tried autoawq + vLLM orders x split_k/group_size) — disabling "
158+ "kernel. AWQ will use ExLlamaV2 or the torch dequant fallback." ,
121159 _GEMM_FN ,
122160 )
123161 return None
@@ -181,7 +219,10 @@ def awq_marlin_repack(
181219 """Repack an AWQ ``qweight`` into Marlin's layout, if a repack kernel
182220 is available. Falls back to returning the input contiguous if not.
183221 """
184- if _REPACK_FN is None :
222+ # Repack ONLY when the resolved GEMM actually consumes the Marlin layout
223+ # (vLLM awq_marlin_gemm). The autoawq raw-layout kernels must keep the
224+ # ORIGINAL AWQ packing, or the GEMM reads garbage.
225+ if _REPACK_FN is None or not _GEMM_NEEDS_REPACK :
185226 return qweight .contiguous ()
186227 try :
187228 return _REPACK_FN (qweight , in_features , out_features , num_bits )
0 commit comments