11 changes: 6 additions & 5 deletions cpp/tensorrt_llm/thop/fp8Op.cpp
@@ -111,11 +111,12 @@ std::tuple<Tensor, Tensor> e4m3_quantize_helper(Tensor input, at::optional<Tenso

if (scales.has_value())
{
// static quantization will use float scales by default.
scales_ = scales.value().clone();
CHECK_TH_CUDA(scales_);
CHECK_TYPE(scales_, torch::kFloat32);
e4m3_static_quantize(input, quantized_input, scales_, stream, quantize_mode);
// static quantization will use float scales by default and output scales_ will be ignored.
scales_
= torch::empty_like(scales.value(), torch::dtype(input.dtype()).device(torch::kCUDA).requires_grad(false));
CHECK_TH_CUDA(scales.value());
CHECK_TYPE(scales.value(), torch::kFloat32);
e4m3_static_quantize(input, quantized_input, scales.value(), stream, quantize_mode);
}
else
{
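The hunk above makes the static path feed `scales.value()` straight into `e4m3_static_quantize` and return only a placeholder `scales_` buffer, since caller-supplied scales are authoritative and the output scales are ignored. As a rough illustration of that contract, here is a minimal PyTorch sketch; `fp8_static_quantize` is a hypothetical helper for illustration, not the TensorRT-LLM op.

```python
# Minimal sketch of caller-supplied ("static") FP8-E4M3 quantization, assuming a
# PyTorch build with torch.float8_e4m3fn. It emulates the contract only, not the kernel.
import torch

E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn


def fp8_static_quantize(x: torch.Tensor, scales: torch.Tensor):
    # Static quantization: scales come from calibration, not from the current tensor,
    # and are passed through untouched instead of being cloned into a new buffer.
    assert scales.dtype == torch.float32
    q = (x.float() / scales).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scales


x = torch.randn(4, 8)
scales = torch.tensor([0.05])  # per-tensor scale from a hypothetical calibration pass
q, s = fp8_static_quantize(x, scales)
print(q.dtype, s.dtype)  # torch.float8_e4m3fn torch.float32
```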
4 changes: 2 additions & 2 deletions jenkins/L0_Test.groovy
@@ -1772,8 +1772,8 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod

// Austin FlexCache looks slow and unstable recently. Remove gh200 temporarily.
// That means gh200 nodes will use the default Blossom data scratch.
if (type.contains("6000d")) {
// rtx-pro-6000d and gh200 nodes are located in Austin DC, we use the FlexCache to speed up the data access.
if (type.contains("6000d") || type.contains("rtx-5080")) {
// rtx-pro-6000d, gh200 and rtx-5080 nodes are located in Austin DC, we use the FlexCache to speed up the data access.
llmModelVolume = """
- name: scratch-trt-llm-data
nfs:
27 changes: 27 additions & 0 deletions tensorrt_llm/_torch/models/modeling_nemotron_nano.py
@@ -1251,6 +1251,8 @@ def __init__(self, model_config: ModelConfig):

llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = llm_model_config.pretrained_config.llm_config
self._update_config_for_quantization(llm_model_config)

self.llm = AutoModelForCausalLM.from_config(llm_model_config)

self.vocab_size = llm_model_config.pretrained_config.vocab_size
@@ -1467,6 +1469,31 @@ def forward(
logger.debug(f"output shape: {output_prob.shape}")
return output_prob

@staticmethod
def _update_config_for_quantization(llm_model_config: ModelConfig) -> None:
# Strip the VL wrapper prefix from exclude_modules and
# quant_config_dict so patterns match the inner LLM's module names
# (e.g. "language_model.backbone.layers.0.mixer.conv1d" becomes
# "backbone.layers.0.mixer.conv1d").
_LM_PREFIX = "language_model."
if llm_model_config.quant_config.exclude_modules is not None:
llm_model_config.quant_config.exclude_modules = [
m[len(_LM_PREFIX) :] if m.startswith(_LM_PREFIX) else m
for m in llm_model_config.quant_config.exclude_modules
]
if llm_model_config.quant_config_dict is not None:
# NOTE: without `_frozen` toggling, `ModelConfig` cannot have its attributes
# modified.
old_frozen = llm_model_config._frozen
llm_model_config._frozen = False
try:
llm_model_config.quant_config_dict = {
k[len(_LM_PREFIX) :] if k.startswith(_LM_PREFIX) else k: v
for k, v in llm_model_config.quant_config_dict.items()
}
finally:
llm_model_config._frozen = old_frozen


def _rearrange_img(x: torch.Tensor, patch_size: int) -> torch.Tensor:
py = x.shape[-2] // patch_size
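The `_update_config_for_quantization` helper added above rewrites quantization-config module names so they match the inner LLM. The same transformation can be reproduced in isolation; the sketch below assumes only a leading `language_model.` segment needs removing, and the second example module name is made up.

```python
# Standalone sketch of the prefix stripping performed by _update_config_for_quantization;
# only a *leading* "language_model." is removed, nested occurrences stay untouched.
_LM_PREFIX = "language_model."


def strip_lm_prefix(name: str) -> str:
    return name[len(_LM_PREFIX):] if name.startswith(_LM_PREFIX) else name


exclude_modules = [
    "language_model.backbone.layers.0.mixer.conv1d",
    "vision_model.patch_embed",  # hypothetical non-LLM entry, left unchanged
]
print([strip_lm_prefix(m) for m in exclude_modules])
# ['backbone.layers.0.mixer.conv1d', 'vision_model.patch_embed']
```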
9 changes: 9 additions & 0 deletions tensorrt_llm/_torch/modules/attention.py
@@ -644,6 +644,15 @@ def convert_qkv(self, q, k, v):
return q, k, v

def _use_quantize_output(self):
# If o_proj can't consume, then no need to quantize the output to nvfp4
if hasattr(self.attn, 'has_nvfp4'
) and self.attn.has_nvfp4 and not self.o_proj.has_nvfp4:
return False
# If no quant is applied, no need to quantize the output
if self.quant_config is not None and not self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
return False

has_awq_pre_quant_scale = hasattr(
self.o_proj,
'pre_quant_scale') and self.o_proj.pre_quant_scale is not None
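The two early returns added to `_use_quantize_output` boil down to a small predicate: skip NVFP4 output quantization when `o_proj` cannot consume it or when no quantization is active for the layer. A condensed sketch with plain booleans standing in for the real attributes:

```python
# Condensed sketch of the added early exits; the booleans stand in for
# self.attn.has_nvfp4, self.o_proj.has_nvfp4 and
# quant_config.layer_quant_mode.has_any_quant(exclude_kv_cache=True).
def should_quantize_attention_output(attn_has_nvfp4: bool,
                                     o_proj_has_nvfp4: bool,
                                     has_any_quant: bool) -> bool:
    if attn_has_nvfp4 and not o_proj_has_nvfp4:
        return False  # o_proj cannot take NVFP4 input, so don't quantize the output
    if not has_any_quant:
        return False  # no quantization applied to this layer at all
    return True  # the real code then falls through to the AWQ pre_quant_scale checks


assert should_quantize_attention_output(True, False, True) is False
assert should_quantize_attention_output(True, True, False) is False
```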
7 changes: 0 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -270,8 +268,6 @@ class Diff:
generation_logits_list: list[torch.Tensor] = field(default_factory=list)
reset_log_probs: tuple[list[TokenLogprobs],
list[float] | None] | None = None
log_probs_list: list[tuple[list[TokenLogprobs], list[float]
| None]] = field(default_factory=list)
mm_embeddings: list[dict[str, Any] | None] = None
mrope_position_ids: dict[str, Any] | None = None
mrope_position_deltas: dict[str, Any] | None = None
@@ -349,9 +347,6 @@ def apply_diff(self, diff: Diff):
self._generation_logits.append(generation_logits)
if diff.reset_log_probs is not None:
self._log_probs.set_log_probs(*diff.reset_log_probs)
if len(diff.log_probs_list) > 0:
for log_probs, cum_log_probs in diff.log_probs_list:
self._log_probs.append(log_probs, cum_log_probs)
if diff.mm_embeddings is not None:
self._mm_embeddings = diff.mm_embeddings
if diff.mrope_position_ids is not None:
@@ -386,7 +381,6 @@ def append_log_probs(self,
cum_log_probs: Optional[list[float]] = None):
if self._log_probs:
self._log_probs.append(log_probs, cum_log_probs)
self.diff.log_probs_list.append((log_probs, cum_log_probs))

def append_mm_embeddings(self, mm_embeddings: torch.Tensor,
multimodal_lengths: List[int]):
@@ -447,7 +441,6 @@ def set_log_probs(self, log_probs: list[TokenLogprobs],
if self._log_probs:
self._log_probs.set_log_probs(log_probs, cum_log_probs)
self.diff.reset_log_probs = (log_probs, cum_log_probs)
self.diff.log_probs_list.clear()

@property
def context_logits(self) -> torch.Tensor | None:
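Dropping `log_probs_list` leaves `reset_log_probs` as the single way log probs travel in a diff. The toy sketch below shows one plausible way replaying both a snapshot and the incrementally appended entries blows the count up to 2N-1, which is the failure mode the new `test_logprobs_pp2` test later in this PR checks; the helper names are illustrative, not the real `Diff` API.

```python
# Toy illustration: "snapshot" plays the role of reset_log_probs and "appended" the
# removed log_probs_list. If a receiver replays both, every appended entry that is
# already inside the snapshot gets counted twice.
def apply_old_diff(target: list, snapshot: list | None, appended: list) -> None:
    if snapshot is not None:
        target[:] = snapshot   # snapshot already contains the appended entries
    target.extend(appended)    # ...so extending again duplicates them


def apply_new_diff(target: list, snapshot: list | None) -> None:
    if snapshot is not None:
        target[:] = snapshot   # single source of truth, no duplication


old, new = [], []
snapshot = [-0.1, -0.2, -0.3]  # log probs for 3 generated tokens
apply_old_diff(old, snapshot, appended=[-0.2, -0.3])
apply_new_diff(new, snapshot)
print(len(old), len(new))  # 5 (= 2*3 - 1) vs 3
```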
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -237,7 +236,6 @@ cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (h
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211)
full:A10/unittest/kv_cache_manager_v2_tests/ SKIP (https://nvbugs/5841954)
unittest/_torch/modeling/test_modeling_nemotron_h.py::test_nemotron_h_cuda_graph_overlap_scheduler SKIP (https://nvbugs/5843316)
examples/test_mistral.py::test_mistral_with_bf16_lora_torch[mistral-7b-v0.1] SKIP (https://nvbugs/5846178)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5846024)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[disable_skip_indexer] SKIP (https://nvbugs/5859886)
Expand Down
33 changes: 33 additions & 0 deletions tests/unittest/_torch/sampler/test_logits_logprobs.py
@@ -759,3 +759,36 @@ def test_logprobs_match_hf_tp2():
print(f"Diff: {(trtllm_logprobs - hf_logprobs).abs()}")

torch.testing.assert_close(trtllm_logprobs, hf_logprobs, atol=0.15, rtol=0)


@pytest.mark.gpu2
def test_logprobs_pp2():
"""Test that logprobs count matches generated token count with PP=2.

Regression test for https://github.com/NVIDIA/TensorRT-LLM/issues/12444
Without the fix, logprobs length = 2N-1 instead of N due to duplication
in the PP ring broadcast diff mechanism.
"""
model_path = os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0")
max_tokens = 16
llm = LLM(
model=model_path,
pipeline_parallel_size=2,
max_batch_size=1,
max_num_tokens=128,
max_seq_len=256,
)

sampling_params = SamplingParams(
max_tokens=max_tokens,
logprobs=5,
)

output = list(llm.generate(["The future of the AI is"], sampling_params=sampling_params))[0]

num_tokens = len(output.outputs[0].token_ids)
num_logprobs = len(output.outputs[0].logprobs)
assert num_logprobs == num_tokens, (
f"logprobs length {num_logprobs} != generated tokens {num_tokens} "
f"(expected 1:1 ratio, got {num_logprobs / num_tokens:.2f}:1)"
)