Commit 69bf7b0

Fix import

Fix precommit
Raise error if LoRA is used

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>

1 parent e1f3b2d commit 69bf7b0

2 files changed: 13 additions & 14 deletions

vllm/model_executor/layers/fused_moe/oracle/mxfp8.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ class MxFp8MoeBackend(Enum):
 def select_mxfp8_moe_backend(
     config: FusedMoEConfig,
 ) -> MxFp8MoeBackend:
+    if config.is_lora_enabled:
+        raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")
+
     AVAILABLE_BACKENDS = [
         MxFp8MoeBackend.FLASHINFER_TRTLLM,
     ]
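
For context on this hunk, below is a minimal, self-contained sketch of how the new guard surfaces to a caller. The FakeMoEConfig stand-in and the trimmed selector are illustrative only; they are not vllm's actual FusedMoEConfig or selection logic.

from dataclasses import dataclass
from enum import Enum


class MxFp8MoeBackend(Enum):
    FLASHINFER_TRTLLM = "flashinfer_trtllm"


@dataclass
class FakeMoEConfig:
    # Stand-in for the single flag the new guard reads from FusedMoEConfig.
    is_lora_enabled: bool = False


def select_mxfp8_moe_backend(config: FakeMoEConfig) -> MxFp8MoeBackend:
    # Mirrors the added check: reject LoRA before any backend is considered.
    if config.is_lora_enabled:
        raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")
    return MxFp8MoeBackend.FLASHINFER_TRTLLM


try:
    select_mxfp8_moe_backend(FakeMoEConfig(is_lora_enabled=True))
except NotImplementedError as err:
    print(f"Rejected during backend selection: {err}")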

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 10 additions & 14 deletions
@@ -5,7 +5,6 @@
 from typing import TYPE_CHECKING, Any

 import torch
-from flashinfer.fused_moe.core import ActivationType, Fp8QuantizationType
 from torch.nn.parameter import Parameter

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1710,9 +1709,11 @@ def __init__(
     ) -> None:
         super().__init__(moe_config)
         self.quant_config = quant_config
-        self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
         assert self.quant_config.is_checkpoint_mxfp8_serialized

+        # Select MXFP8 MoE backend
+        self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1835,16 +1836,6 @@ def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None:
         is_gated = self.moe.is_act_and_mul
         intermediate_size_factor = 2 if is_gated else 1

-        logger.debug(
-            "MXFP8 MoE: activation=%s is_act_and_mul=%s "
-            "w13_weight_shape=%s w13_scale_shape=%s w2_weight_shape=%s",
-            self.moe.activation,
-            is_gated,
-            tuple(layer.w13_weight.shape),
-            tuple(layer.w13_weight_scale.shape),
-            tuple(layer.w2_weight.shape),
-        )
-
         w13_weight = layer.w13_weight.data
         w13_scale = layer.w13_weight_scale.data
         if is_gated:
@@ -1927,15 +1918,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
-    ) -> mk.FusedMoEPrepareAndFinalize | None:
+    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
         raise ValueError(
             f"{self.__class__.__name__} uses the new modular kernel initialization "
             "logic. This function should not be called."
         )

     def select_gemm_impl(
         self,
-        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
         layer: torch.nn.Module,
     ) -> mk.FusedMoEExpertsModular:
         raise ValueError(
@@ -1959,6 +1950,11 @@ def apply_monolithic(
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        from flashinfer.fused_moe.core import (
+            ActivationType,
+            Fp8QuantizationType,
+        )
+
         assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM

         if layer.enable_eplb:
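
The last hunk pairs with the deleted module-level import: the flashinfer symbols are now imported inside apply_monolithic, so modelopt.py can be imported without flashinfer installed. A rough sketch of that deferred-import pattern follows; the wrapper function name and error message are illustrative, not vllm's.

import torch


def run_flashinfer_trtllm_path(x: torch.Tensor) -> torch.Tensor:
    # Deferred import: the dependency is only resolved when this MXFP8 path
    # actually executes, not when the enclosing module is imported.
    try:
        from flashinfer.fused_moe.core import (
            ActivationType,
            Fp8QuantizationType,
        )
    except ImportError as err:
        raise ImportError(
            "flashinfer is required for the FLASHINFER_TRTLLM MXFP8 MoE backend"
        ) from err
    # Referenced only to show the symbols are in scope here; a real
    # implementation would build the fused MoE kernel call from them.
    _ = (ActivationType, Fp8QuantizationType)
    return x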
