@@ -294,6 +294,12 @@ def __init__(self, moe: FusedMoEConfig):
294294 # Initialized in process_weights_after_loading for CUTLASS/SM90 backends
295295 self .moe_kernel : mk .FusedMoEKernel | None = None
296296
@property
def skip_forward_padding(self) -> bool:
    """Whether the caller may skip hidden-dim padding before invoking this MoE method.

    The SM100_FI_MXFP4_MXFP8_TRTLLM backend performs its own padding as part
    of mxfp8 quantization, so pre-padding in the forward pass is unnecessary.
    """
    backend = self.mxfp4_backend
    return backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
297303 def create_weights (
298304 self ,
299305 layer : torch .nn .Module ,
@@ -1130,9 +1136,17 @@ def apply_monolithic(
11301136 elif self .mxfp4_backend == Mxfp4Backend .SM100_FI_MXFP4_MXFP8_TRTLLM :
11311137 from flashinfer import mxfp8_quantize
11321138
1133- x_quant , x_scale = mxfp8_quantize (x , False ) # to mxfp8
1139+ # x_quant is padded in hidden dimension with alignment=256
1140+ x_quant , x_scale = mxfp8_quantize (
1141+ x ,
1142+ is_sf_swizzled_layout = False ,
1143+ alignment = 256 ,
1144+ )
11341145 x_scale = x_scale .view (torch .float8_e4m3fn ).reshape (* x .shape [:- 1 ], - 1 )
11351146
1147+ # output with original unpadded hidden size
1148+ output = torch .empty_like (x )
1149+
11361150 trtllm_gen_output = trtllm_fp4_block_scale_moe (
11371151 routing_logits = router_logits .to (torch .bfloat16 ),
11381152 routing_bias = None ,
@@ -1161,6 +1175,7 @@ def apply_monolithic(
11611175 routing_method_type = 1 if layer .renormalize else 0 ,
11621176 do_finalize = True ,
11631177 tune_max_num_tokens = max (self .max_capture_size , 1 ),
1178+ output = output ,
11641179 )[0 ]
11651180 return trtllm_gen_output
11661181 elif self .mxfp4_backend == Mxfp4Backend .CK :
0 commit comments