2 files changed: +8 −8 lines changed, both under vllm_rbln/model_executor/layers/quantization.

First file:

@@ -811,6 +811,9 @@ def process_weights_after_loading(self, layer):
             max_w13_scales, requires_grad=False
         )
 
+        if getattr(layer, "_expert_map", None) is not None:
+            layer._expert_map_list = layer._expert_map.data.to(dtype=torch.int32).tolist()
+
     def select_gemm_impl(
         self,
         prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,

@@ -863,10 +866,7 @@ def apply(
 
         expert_map_const = None
        if layer.expert_map is not None:
-            expert_map_const = layer.expert_map
-            if expert_map_const.dtype != torch.int32:
-                expert_map_const = expert_map_const.to(dtype=torch.int32)
-            expert_map_const = expert_map_const.detach().clone()
+            expert_map_const = torch.tensor(layer._expert_map_list, dtype=torch.int32)
 
         tokens_mask = None
         use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
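The change in both files follows the same pattern: snapshot the expert map once, as a plain Python list of ints, when weights are finalized, then rebuild it as a fresh int32 constant tensor on every apply() call instead of detaching and cloning the live tensor. One plausible motivation (not stated in the diff) is that a tensor built from a Python list is a true constant for graph capture, rather than an alias of a live buffer. Below is a minimal, self-contained sketch of the pattern; FakeLayer and the free functions are hypothetical stand-ins, and only _expert_map, _expert_map_list, and expert_map mirror identifiers visible in the diff.

# Hedged sketch of the caching pattern, assuming a layer object that carries
# an expert-map tensor. FakeLayer is hypothetical.
import torch

class FakeLayer:
    def __init__(self, expert_map: torch.Tensor):
        # Live mapping tensor (in vLLM, expert_map typically maps global
        # expert ids to local ids, with -1 for experts not on this rank).
        self._expert_map = expert_map
        self.expert_map = expert_map

def process_weights_after_loading(layer: FakeLayer) -> None:
    # Snapshot once as a plain Python list, normalized to int32, so later
    # calls never need to touch the live tensor.
    if getattr(layer, "_expert_map", None) is not None:
        layer._expert_map_list = layer._expert_map.data.to(dtype=torch.int32).tolist()

def apply(layer: FakeLayer):
    expert_map_const = None
    if layer.expert_map is not None:
        # Materialize the cached list as a fresh int32 constant tensor.
        expert_map_const = torch.tensor(layer._expert_map_list, dtype=torch.int32)
    return expert_map_const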
Second file:

@@ -375,6 +375,9 @@ def process_weights_after_loading(self, layer):
         layer.register_buffer("down_proj_scales", layer.w2_weight_scale.data)
         layer.register_buffer("down_proj_bias", layer.w2_bias.data)
 
+        if getattr(layer, "_expert_map", None) is not None:
+            layer._expert_map_list = layer._expert_map.data.to(dtype=torch.int32).tolist()
+
     def select_gemm_impl(
         self,
         prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,

@@ -407,10 +410,7 @@ def apply(
         if layer.activation == MoEActivation.SWIGLUOAI:
             expert_map_const = None
             if layer.expert_map is not None:
-                assert getattr(layer, "expert_map_const", None) is not None
-                expert_map_const = torch.tensor(
-                    layer.expert_map_const, dtype=torch.int32
-                )
+                expert_map_const = torch.tensor(layer._expert_map_list, dtype=torch.int32)
 
         tokens_mask = None
         use_moe_tokens_mask = envs.VLLM_RBLN_USE_MOE_TOKENS_MASK
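A quick round-trip check of the sketch above (values are made up):

layer = FakeLayer(torch.tensor([0, 1, -1, 2], dtype=torch.int64))
process_weights_after_loading(layer)
const = apply(layer)
assert const.dtype == torch.int32
assert const.tolist() == [0, 1, -1, 2]

Note the cleanup this enables in the second file: the old code asserted that a separately maintained expert_map_const list existed before converting it, whereas the new code derives _expert_map_list directly in process_weights_after_loading whenever _expert_map is set, so the assert is no longer needed.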