Skip to content

Commit 191cb99

Browse files
committed
fix torch export in moe
1 parent 784e93f commit 191cb99

File tree

2 files changed: +8 −8 lines

vllm_rbln/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,9 @@ def process_weights_after_loading(self, layer):
811811
max_w13_scales, requires_grad=False
812812
)
813813

814+
if getattr(layer, "_expert_map", None) is not None:
815+
layer._expert_map_list = layer._expert_map.data.to(dtype=torch.int32).tolist()
816+
814817
def select_gemm_impl(
815818
self,
816819
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
@@ -863,10 +866,7 @@ def apply(
863866

864867
expert_map_const = None
865868
if layer.expert_map is not None:
866-
expert_map_const = layer.expert_map
867-
if expert_map_const.dtype != torch.int32:
868-
expert_map_const = expert_map_const.to(dtype=torch.int32)
869-
expert_map_const = expert_map_const.detach().clone()
869+
expert_map_const = torch.tensor(layer._expert_map_list, dtype=torch.int32)
870870

871871
tokens_mask = None
872872
use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK

vllm_rbln/model_executor/layers/quantization/mxfp4.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ def process_weights_after_loading(self, layer):
375375
layer.register_buffer("down_proj_scales", layer.w2_weight_scale.data)
376376
layer.register_buffer("down_proj_bias", layer.w2_bias.data)
377377

378+
if getattr(layer, "_expert_map", None) is not None:
379+
layer._expert_map_list = layer._expert_map.data.to(dtype=torch.int32).tolist()
380+
378381
def select_gemm_impl(
379382
self,
380383
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
@@ -407,10 +410,7 @@ def apply(
407410
if layer.activation == MoEActivation.SWIGLUOAI:
408411
expert_map_const = None
409412
if layer.expert_map is not None:
410-
assert getattr(layer, "expert_map_const", None) is not None
411-
expert_map_const = torch.tensor(
412-
layer.expert_map_const, dtype=torch.int32
413-
)
413+
expert_map_const = torch.tensor(layer._expert_map_list, dtype=torch.int32)
414414

415415
tokens_mask = None
416416
use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK

Comments (0)