Skip to content

Commit ce80c14

Browse files
authored
[None][fix] skip inference_mode() when torch.compile=True for gemma3 fp8 (NVIDIA#12367)
Signed-off-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>
1 parent 79d2e37 commit ce80c14

File tree

6 files changed

+30
-8
lines changed

6 files changed

+30
-8
lines changed

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..modules.gated_mlp import GatedMLP
2323
from ..modules.linear import TensorParallelMode
2424
from ..modules.rms_norm import RMSNorm
25+
from ..utils import inference_mode_unless_compiling
2526
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
2627
register_auto_model)
2728

@@ -47,7 +48,7 @@ def __init__(
4748
)
4849
self.embed_scale = torch.sqrt(torch.tensor(hidden_size)).to(self.dtype)
4950

50-
@torch.inference_mode()
51+
@inference_mode_unless_compiling
5152
def forward(self, input_ids):
5253
return super().forward(input_ids) * self.embed_scale
5354

@@ -90,7 +91,7 @@ def __init__(
9091
q_scaling=q_scaling,
9192
)
9293

93-
@torch.inference_mode()
94+
@inference_mode_unless_compiling
9495
def forward(
9596
self,
9697
position_ids: Optional[torch.IntTensor],
@@ -163,7 +164,7 @@ def __init__(
163164
eps=config.rms_norm_eps,
164165
dtype=config.torch_dtype)
165166

166-
@torch.inference_mode()
167+
@inference_mode_unless_compiling
167168
def forward(
168169
self,
169170
position_ids: torch.IntTensor,
@@ -222,7 +223,7 @@ def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
222223
eps=config.pretrained_config.rms_norm_eps,
223224
dtype=config.pretrained_config.torch_dtype)
224225

225-
@torch.inference_mode()
226+
@inference_mode_unless_compiling
226227
def forward(
227228
self,
228229
attn_metadata: AttentionMetadata,
@@ -392,7 +393,7 @@ def get_flashinfer_attention_mask(
392393
context_mask_list.append(mask_i.flatten())
393394
return torch.cat(context_mask_list, dim=0).contiguous()
394395

395-
@torch.inference_mode()
396+
@inference_mode_unless_compiling
396397
def forward(
397398
self,
398399
attn_metadata: AttentionMetadata,

tensorrt_llm/_torch/modules/qk_norm_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def k_l2norm():
228228
self.ln_events[0],
229229
self.ln_events[1],
230230
self.aux_stream,
231+
disable_on_compile=True,
231232
)
232233

233234
return q, k

tensorrt_llm/_torch/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import functools
23
import os
34
import threading
45
from dataclasses import dataclass
@@ -424,6 +425,19 @@ def wrapper(*args, **kwargs):
424425
return decorator(func) if func else decorator
425426

426427

428+
# This decorator selectively disables inference_mode() to avoid conflicts with torch.dynamo tracing.
429+
def inference_mode_unless_compiling(func):
430+
431+
@functools.wraps(func)
432+
def wrapper(*args, **kwargs):
433+
if torch.compiler.is_compiling():
434+
return func(*args, **kwargs)
435+
with torch.inference_mode():
436+
return func(*args, **kwargs)
437+
438+
return wrapper
439+
440+
427441
def split(x: torch.Tensor,
428442
tp_size: int,
429443
idx: int,

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,14 +1317,17 @@ def test_auto_dtype(self):
13171317
task = MMLU(self.MODEL_NAME)
13181318
task.evaluate(llm)
13191319

1320-
def test_fp8_prequantized(self):
1320+
@parametrize_with_ids("torch_compile", [False, True])
1321+
def test_fp8_prequantized(self, torch_compile):
13211322
# Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
13221323
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
13231324
enable_partial_reuse=False,
13241325
dtype="fp8")
1326+
torch_compile_config = _get_default_torch_compile_config(torch_compile)
13251327
prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it-fp8/"
13261328
with LLM(prequantized_model_path,
1327-
kv_cache_config=kv_cache_config) as llm:
1329+
kv_cache_config=kv_cache_config,
1330+
torch_compile_config=torch_compile_config) as llm:
13281331
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
13291332
task = CnnDailymail(self.MODEL_NAME)
13301333
task.evaluate(llm)

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2
272272
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_auto_dtype
273273
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8
274274
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
275+
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized[torch_compile=False]
276+
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized[torch_compile=True]
275277
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_vswa_reuse
276278
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_guided_decoding_vswa_reuse[xgrammar]
277279
accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,8 @@ l0_h100:
330330
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
331331
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
332332
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0]
333-
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
333+
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized[torch_compile=False]
334+
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized[torch_compile=True]
334335
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized
335336
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype TIMEOUT (90)
336337
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype

0 commit comments

Comments (0)