Skip to content

Commit c6be251

Browse files
authored
[NPU] RL update_weights_from_disk/ tensor /distributed (#26717)
1 parent cd6efcb commit c6be251

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

  • python/sglang/srt/layers/moe/fused_moe_triton

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
get_bool_env_var,
7171
is_cpu,
7272
is_hip,
73+
is_npu,
7374
print_info_once,
7475
round_up,
7576
)
@@ -78,6 +79,7 @@
7879
_is_hip = is_hip()
7980
_is_cpu_amx_available = cpu_has_amx_support()
8081
_is_cpu = is_cpu()
82+
_is_npu = is_npu()
8183
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
8284

8385

@@ -778,6 +780,13 @@ def _weight_loader_impl(
778780
# expert weights into block layout. During weight update, we must restore
779781
# canonical load-time shapes before copying checkpoint tensors.
780782
if isinstance(method, UnquantizedFusedMoEMethod):
783+
if _is_npu:
784+
if weight_name.endswith(".experts.w2_weight"):
785+
if param.data.shape[1] != loaded_weight.shape[0]:
786+
param.data = param.data.transpose(1, 2).contiguous()
787+
if weight_name.endswith(".experts.w13_weight"):
788+
if param.data.shape[2] != loaded_weight.shape[1]:
789+
param.data = param.data.transpose(1, 2).contiguous()
781790
method.maybe_restore_flashinfer_trtllm_bf16_weight_shape_for_load(
782791
layer=self,
783792
param=param,

0 commit comments

Comments
 (0)