@@ -238,13 +238,17 @@ def unquantized_fused_moe_method_rbln(
     return final_hidden_states.reshape(orig_shape)


-def _get_tokens_mask():
-    num_tokens = \
+def get_tokens_mask(num_tokens: int, left=1.0, right=float('-inf')):
+    num_tokens_across_dp = \
         get_forward_context().dp_metadata.num_tokens_across_dp_cpu
-    num_tokens = num_tokens.unsqueeze(1)
-    max_pad = get_forward_context().dp_metadata.max_pads_across_dp
+    num_tokens_across_dp = num_tokens_across_dp.unsqueeze(1)
+    if num_tokens_across_dp.size(0) == 1:
+        max_pad = num_tokens
+    else:
+        max_pad = get_forward_context().dp_metadata.max_pads_across_dp
     pos = torch.arange(max_pad, dtype=torch.int32).unsqueeze(0)  # [1, max_pad]
-    tokens_mask = torch.where(pos < num_tokens, 1.0, 0.0)  # [dp_size, max_pad]
+    tokens_mask = torch.where(pos < num_tokens_across_dp, left,
+                              right)  # [dp_size, max_pad]
     tokens_mask = tokens_mask.reshape(-1, 1)  # [dp_size * max_pad, 1]
     return tokens_mask
@@ -268,7 +272,7 @@ def get_masked_routing_weights(router_logits, top_k, renormalize, expert_map):

     use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
     if use_moe_tokens_mask:
-        tokens_mask = _get_tokens_mask()
+        tokens_mask = get_tokens_mask(router_logits.shape[0], 1.0, 0.0)
         selected_weights = selected_weights * tokens_mask

     n_expert = router_logits.shape[1]
@@ -393,6 +397,11 @@ def unquantized_fused_optimize_moe_method_custom(
     expert_map_list = expert_map.tolist()
     expert_map_const = torch.tensor(expert_map_list, dtype=torch.int32)

+    use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
+    if use_moe_tokens_mask:
+        tokens_mask = get_tokens_mask(num_tokens)
+        router_logits = router_logits * tokens_mask
+
     # optimum-rbln/src/optimum/rbln/transformers/models/qwen3_moe/
     # qwen3_moe_architecture.py
     final_hidden_states = torch.ops.rbln_custom_ops.custom_moe_glu(
0 commit comments