Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -77,6 +77,9 @@ def __init__(
self.deep_ep_buffer = buffer_pool.get_buffer(mapping)
self.deep_ep_buffer.reserve(hidden_size, weight_dtype)

# Invalid token expert ID: TRTLLM-gen kernels only support -1 for invalid tokens.
self.invalid_token_expert_id = -1

def destroy(self):
"""Release the DeepEP buffer to prevent deadlock/hang.

Expand Down Expand Up @@ -233,6 +236,19 @@ def dispatch(
"padded": padded,
}

if kwargs.get("enable_sanitize_expert_ids", False) and token_selected_slots.numel() > 0:
# After dispatch, non-local expert slots are replaced with invalid_token_expert_id.
# Some renormalize kernels (though not all of them yet) may already do this sanitization,
# but we want to make sure it is always done for non-local tokens to avoid potential issues.
slot_start = self.expert_size_per_partition * self.ep_rank
slot_end = slot_start + self.expert_size_per_partition
non_local_mask = (token_selected_slots < slot_start) | (
token_selected_slots >= slot_end
)
token_selected_slots = token_selected_slots.masked_fill(
non_local_mask, self.invalid_token_expert_id
)

# Restore token_final_scales to original dtype for downstream consumers
if (
token_final_scales is not None
Expand Down
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,10 @@ def _forward_chunk_impl(
if self.enable_dummy_allreduce:
self.dummy_allreduce()

dispatch_kwargs = dict(eplb_dispatch_kwargs)
if isinstance(self.comm, DeepEP) and isinstance(self.backend, TRTLLMGenFusedMoE):
dispatch_kwargs["enable_sanitize_expert_ids"] = True

if supports_post_quant:
# ===== Post-quant flow: Quantize → Dispatch =====

Expand All @@ -724,7 +728,6 @@ def _forward_chunk_impl(
# Step 4b: Dispatch AFTER quantization
# Get pre_quant_scale for W4AFP8 if available (only DeepEPLowLatency needs it)
# Other strategies will ignore this via **kwargs, so it's safe to pass unconditionally
dispatch_kwargs = dict(eplb_dispatch_kwargs)
if hasattr(self, "quant_scales") and self.quant_scales is not None:
if hasattr(self.quant_scales, "pre_quant_scale_1"):
dispatch_kwargs["pre_quant_scale"] = self.quant_scales.pre_quant_scale_1
Expand All @@ -751,6 +754,7 @@ def _forward_chunk_impl(
token_final_scales=token_final_scales,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
**dispatch_kwargs,
)

# Step 4b: Quantization AFTER dispatch
Expand Down
16 changes: 0 additions & 16 deletions tests/unittest/_torch/modules/moe/moe_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,22 +352,6 @@ def should_skip_trtllm(
f"Single-GPU tests pass; issue is in the kernel runner under EP."
)

# Issue: NVFP4 with large model configs crashes with CUDA illegal memory
# access in DeepEP mode (deep_ep.cpp:86).
# Verified: e60_k4_h2048_i1408 passes, e256_k8_h7168_i2048 crashes.
# The crash kills the entire pytest process, blocking all subsequent tests.
if (
quant_algo == QuantAlgo.NVFP4
and num_experts >= 256
and model_config.hidden_size >= 7168
):
return (
f"[Potential Bug] TRTLLMGenFusedMoE NVFP4 with large model "
f"(num_experts={num_experts}, hidden_size={model_config.hidden_size}) "
f"crashes with CUDA illegal memory access in DeepEP mode "
f"(comm={comm_method}). Smaller configs pass."
)

# TP per-shard alignment: when moe_tp_size > 1, intermediate_size is sharded.
# MXFP4 variants (W4A16_MXFP4, W4A8_MXFP4_MXFP8) auto-pad to 128 alignment,
# but other quants (FP8_BLOCK_SCALES, NVFP4, W4A8_NVFP4_FP8) crash:
Expand Down
Loading