
Commit 562c8da

fix: add parallel params back

1 parent: ab90de3

1 file changed:
docker/npu_patch/qwen3_vl_8b_multi_turn_grpo/slime.patch

Lines changed: 6 additions & 11 deletions

Note that the file under change, slime.patch, is itself a diff, so the hunks below are a patch of a patch: the leading - / + markers are this commit's edits to slime.patch, while any marker appearing after them belongs to the inner diff that the patch file carries.
@@ -301,7 +301,7 @@ index 00000000..13cf674a
 +
 +    return hf_name_param
 diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py
-index 8174c7ac..cdd57524 100644
+index 8174c7ac..33fc9a99 100644
 --- a/slime/backends/megatron_utils/model_provider.py
 +++ b/slime/backends/megatron_utils/model_provider.py
 @@ -17,7 +17,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig
@@ -331,23 +331,18 @@ index 8174c7ac..cdd57524 100644
      # Support custom model provider path (similar to --custom-rm-path for reward models)
      if getattr(args, "custom_model_provider_path", None):

-@@ -83,11 +85,14 @@ def get_model_provider_func(
-     bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
-     provider = bridge.to_megatron_provider(load_weights=False)
-     # TODO: we should not manually set this...
--    provider.tensor_model_parallel_size = args.tensor_model_parallel_size
--    provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size
--    provider.expert_model_parallel_size = args.expert_model_parallel_size
--    provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size
--    provider.sequence_parallel = args.sequence_parallel
+@@ -88,6 +90,14 @@ def get_model_provider_func(
+     provider.expert_model_parallel_size = args.expert_model_parallel_size
+     provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size
+     provider.sequence_parallel = args.sequence_parallel
 +    provider.gradient_accumulation_fusion = args.gradient_accumulation_fusion
 +    provider.recompute_granularity = args.recompute_granularity
 +    provider.recompute_method = args.recompute_method
 +    provider.recompute_num_layers = args.recompute_num_layers
 +    for key, value in vars(args).items():
 +        if hasattr(provider, key):
 +            continue
-+            setattr(provider, key, value)
++        setattr(provider, key, value)
      provider.finalize()
      return provider.provide
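Net effect of the second hunk: an earlier revision of slime.patch made the inner diff delete the five manual parallel-size assignments; this commit adds them back, with three of them now visible as plain context lines and the other two no longer appearing in the hunk at all. The removed and re-added setattr lines differ only in leading whitespace, which suggests the earlier version left setattr unreachable beneath the continue; the new version dedents it so it actually runs.

The resulting code follows a common configuration-merge pattern: copy a few fields from args explicitly, then bulk-copy every remaining args attribute the provider does not already define, so the explicitly set (or bridge-derived) values are never clobbered. Below is a minimal, self-contained sketch of that pattern; ProviderStub and apply_args are hypothetical stand-ins for slime's real provider and wiring, with an argparse Namespace playing the role of args:

    from argparse import Namespace
    from dataclasses import dataclass

    @dataclass
    class ProviderStub:
        # Hypothetical stand-ins for fields the bridge would normally derive.
        tensor_model_parallel_size: int = 1
        pipeline_model_parallel_size: int = 1
        sequence_parallel: bool = False

    def apply_args(provider: ProviderStub, args: Namespace) -> ProviderStub:
        # Explicit copies, mirroring the patched hunk.
        provider.tensor_model_parallel_size = args.tensor_model_parallel_size
        provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size
        provider.sequence_parallel = args.sequence_parallel
        # Bulk copy: only attach attributes the provider lacks, so the
        # assignments above are never overwritten.
        for key, value in vars(args).items():
            if hasattr(provider, key):
                continue
            setattr(provider, key, value)
        return provider

    args = Namespace(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
        sequence_parallel=True,
        recompute_granularity="full",  # absent from the stub: added by the loop
    )
    provider = apply_args(ProviderStub(), args)
    print(provider.tensor_model_parallel_size)  # 2
    print(provider.recompute_granularity)       # full

One consequence of the hasattr guard: the loop never updates attributes the provider already has, which is why the parallel sizes and recompute settings still need their explicit assignments above it.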