 from typing import TYPE_CHECKING, Any

 import torch
-from flashinfer.fused_moe.core import ActivationType, Fp8QuantizationType
 from torch.nn.parameter import Parameter

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1710,9 +1709,11 @@ def __init__(
     ) -> None:
         super().__init__(moe_config)
         self.quant_config = quant_config
-        self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
         assert self.quant_config.is_checkpoint_mxfp8_serialized

+        # Select MXFP8 MoE backend
+        self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1835,16 +1836,6 @@ def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None:
         is_gated = self.moe.is_act_and_mul
         intermediate_size_factor = 2 if is_gated else 1

-        logger.debug(
-            "MXFP8 MoE: activation=%s is_act_and_mul=%s "
-            "w13_weight_shape=%s w13_scale_shape=%s w2_weight_shape=%s",
-            self.moe.activation,
-            is_gated,
-            tuple(layer.w13_weight.shape),
-            tuple(layer.w13_weight_scale.shape),
-            tuple(layer.w2_weight.shape),
-        )
-
         w13_weight = layer.w13_weight.data
         w13_scale = layer.w13_weight_scale.data
         if is_gated:
@@ -1927,15 +1918,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
19271918 def maybe_make_prepare_finalize (
19281919 self ,
19291920 routing_tables : tuple [torch .Tensor , torch .Tensor , torch .Tensor ] | None = None ,
1930- ) -> mk .FusedMoEPrepareAndFinalize | None :
1921+ ) -> mk .FusedMoEPrepareAndFinalizeModular | None :
19311922 raise ValueError (
19321923 f"{ self .__class__ .__name__ } uses the new modular kernel initialization "
19331924 "logic. This function should not be called."
19341925 )
19351926
19361927 def select_gemm_impl (
19371928 self ,
1938- prepare_finalize : mk .FusedMoEPrepareAndFinalize ,
1929+ prepare_finalize : mk .FusedMoEPrepareAndFinalizeModular ,
19391930 layer : torch .nn .Module ,
19401931 ) -> mk .FusedMoEExpertsModular :
19411932 raise ValueError (
@@ -1959,6 +1950,11 @@ def apply_monolithic(
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        from flashinfer.fused_moe.core import (
+            ActivationType,
+            Fp8QuantizationType,
+        )
+
         assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM

         if layer.enable_eplb: