@@ -112,6 +112,13 @@ def __init__(
112112 # Cost coefficients (probed at server startup) for cost-aware
113113 # tree shape + mode selection. None = degraded mode (defaults).
114114 cost_coefficients : Any = None ,
115+ # torch.compile(mode="reduce-overhead") — captures CUDA graphs +
116+ # fuses pointwise ops (RMSNorm/RoPE/SwiGLU/residual) → removes the
117+ # per-token kernel-launch overhead that caps eager decode. Opt-in
118+ # and EXPERIMENTAL: compilation is lazy (first forward), so a bad
119+ # interaction surfaces at runtime, not here — drop the flag if a
120+ # run errors. Off by default so it can never regress the eager path.
121+ compile_model : bool = False ,
115122 ):
116123 if device is None :
117124 device = default_device ()
@@ -282,6 +289,21 @@ def __init__(
282289 from .flash_attn_ext import install_flash_attention
283290 self ._flash_attn_patched = install_flash_attention (self .model )
284291
292+ # torch.compile LAST, after any model patching. reduce-overhead mode
293+ # uses CUDA graphs + Triton fusion to erase per-token launch overhead
294+ # (the gap between eager decode and the bandwidth ceiling). Lazy: the
295+ # actual compile happens on the first forward, so we can't catch a
296+ # failure here — wrap-time errors are caught; runtime graph-breaks just
297+ # degrade to partial speedup. Drop --compile if a run errors outright.
298+ self ._compiled = False
299+ if compile_model :
300+ try :
301+ self .model = torch .compile (self .model , mode = "reduce-overhead" )
302+ self ._compiled = True
303+ log .info ("torch.compile(reduce-overhead) enabled (experimental)" )
304+ except Exception as e :
305+ log .warning ("torch.compile failed (%s); running eager" , e )
306+
285307 # ------------------------------------------------------------------
286308 # Factory
287309 # ------------------------------------------------------------------
@@ -293,6 +315,7 @@ def from_pretrained(
293315 strict : bool = True ,
294316 streaming_config : Optional ["StreamingConfig" ] = None ,
295317 awq_path : Optional [str ] = None ,
318+ attn_implementation : str = "auto" ,
296319 ** kwargs ,
297320 ) -> "InferenceEngine" :
298321 """
@@ -309,7 +332,21 @@ def from_pretrained(
309332 the config. The rest of KVBoost (chunk-reuse, FlashAttn)
310333 is untouched.
311334 awq_path: Optional path hint forwarded to the streaming loader.
312- **kwargs: Passed to InferenceEngine.__init__().
335+ attn_implementation: Attention backend for the resident path.
336+ ``"auto"`` (default) tries ``flash_attention_2`` (FA2 —
337+ Ampere+ wheel; faster, lower-memory prefill → better TTFT)
338+ and silently falls back to ``"sdpa"`` if FA2 isn't
339+ installed/supported. Pass ``"flash_attention_2"`` to
340+ require it (raises if unavailable), or ``"sdpa"`` /
341+ ``"eager"`` to force a backend. Ignored on the streaming
342+ path. To load a **quantized** checkpoint (AWQ/GPTQ →
343+ Marlin int4 GEMM on Ampere, ~4× less weight bandwidth →
344+ higher decode tok/s), just pass a quantized ``model_name``;
345+ transformers reads its quantization_config and picks the
346+ kernel automatically — the engine already detects and
347+ leaves quantized/offloaded weights in place.
348+ **kwargs: Passed to InferenceEngine.__init__() (e.g.
349+ ``compile_model=True`` for torch.compile reduce-overhead).
313350 """
314351 log .info ("Loading model %s ..." , model_name )
315352 tokenizer = AutoTokenizer .from_pretrained (model_name )
@@ -326,11 +363,31 @@ def from_pretrained(
326363 dtype = torch .float16 ,
327364 )
328365 else :
329- model = AutoModelForCausalLM .from_pretrained (
330- model_name ,
331- torch_dtype = torch .float16 ,
332- low_cpu_mem_usage = True ,
333- )
366+ load_kwargs = dict (torch_dtype = torch .float16 , low_cpu_mem_usage = True )
367+ impl = attn_implementation
368+ if impl in ("auto" , "flash_attention_2" ):
369+ try :
370+ model = AutoModelForCausalLM .from_pretrained (
371+ model_name ,
372+ attn_implementation = "flash_attention_2" ,
373+ ** load_kwargs ,
374+ )
375+ log .info ("Attention backend: flash_attention_2" )
376+ except Exception as e :
377+ if impl == "flash_attention_2" :
378+ raise # caller explicitly required FA2 — don't mask it
379+ log .info (
380+ "flash_attention_2 unavailable (%s); using sdpa" , e
381+ )
382+ model = AutoModelForCausalLM .from_pretrained (
383+ model_name , attn_implementation = "sdpa" , ** load_kwargs
384+ )
385+ log .info ("Attention backend: sdpa" )
386+ else :
387+ model = AutoModelForCausalLM .from_pretrained (
388+ model_name , attn_implementation = impl , ** load_kwargs
389+ )
390+ log .info ("Attention backend: %s" , impl )
334391 model .eval ()
335392
336393 check_model_compatibility (model , strict = strict )
@@ -1110,16 +1167,25 @@ def _decode_with_kv(
11101167
11111168 # ----- autoregressive decode ------------------------------------
11121169 cur_pos = cached_len + len (live_ids )
1170+ # Pre-allocate the (1,1) decode input buffers ONCE and write in place
1171+ # each step, instead of allocating two fresh device tensors per token.
1172+ # Removes a per-token alloc + H2D churn, and — critically — gives
1173+ # torch.compile / CUDA-graph capture stable input tensors to graph
1174+ # against (a graph needs fixed input storage; a new tensor per step
1175+ # forces a recapture/recompile). Correctness is identical: the model
1176+ # reads these tensors, it never retains them across steps.
1177+ input_buf = torch .empty ((1 , 1 ), dtype = torch .long , device = self .device )
1178+ pos_buf = torch .empty ((1 , 1 ), dtype = torch .long , device = self .device )
11131179 while not goto_done and len (generated ) < max_new_tokens :
11141180 if generated [- 1 ] == self .tokenizer .eos_token_id :
11151181 break
1116- cur_ids = torch . tensor ([[ generated [- 1 ]]], dtype = torch . long , device = self . device )
1117- pos_ids = torch . tensor ([[ cur_pos ]], dtype = torch . long , device = self . device )
1182+ input_buf [ 0 , 0 ] = generated [- 1 ]
1183+ pos_buf [ 0 , 0 ] = cur_pos
11181184 with torch .no_grad (), last_logit_only (self .model ):
11191185 out = self .model (
1120- input_ids = cur_ids ,
1186+ input_ids = input_buf ,
11211187 past_key_values = self ._as_cache (past_kv ),
1122- position_ids = pos_ids ,
1188+ position_ids = pos_buf ,
11231189 use_cache = True ,
11241190 )
11251191 past_kv = self ._normalize_past_kv (out .past_key_values )
0 commit comments