@@ -301,7 +301,7 @@ index 00000000..13cf674a
301301+
302302+ return hf_name_param
303303diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py
304- index 8174c7ac..cdd57524 100644
304+ index 8174c7ac..33fc9a99 100644
305305--- a/slime/backends/megatron_utils/model_provider.py
306306+++ b/slime/backends/megatron_utils/model_provider.py
307307@@ -17,7 +17,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig
@@ -331,23 +331,18 @@ index 8174c7ac..cdd57524 100644
331331 # Support custom model provider path (similar to --custom-rm-path for reward models)
332332 if getattr(args, "custom_model_provider_path", None):
333333
334- @@ -83,11 +85,14 @@ def get_model_provider_func(
335- bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
336- provider = bridge.to_megatron_provider(load_weights=False)
337- # TODO: we should not manually set this...
338- - provider.tensor_model_parallel_size = args.tensor_model_parallel_size
339- - provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size
340- - provider.expert_model_parallel_size = args.expert_model_parallel_size
341- - provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size
342- - provider.sequence_parallel = args.sequence_parallel
334+ @@ -88,6 +90,14 @@ def get_model_provider_func(
335+ provider.expert_model_parallel_size = args.expert_model_parallel_size
336+ provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size
337+ provider.sequence_parallel = args.sequence_parallel
343338+ provider.gradient_accumulation_fusion = args.gradient_accumulation_fusion
344339+ provider.recompute_granularity = args.recompute_granularity
345340+ provider.recompute_method = args.recompute_method
346341+ provider.recompute_num_layers = args.recompute_num_layers
347342+ for key, value in vars(args).items():
348343+ if hasattr(provider, key):
349344+ continue
350- + setattr(provider, key, value)
345+ + setattr(provider, key, value)
351346 provider.finalize()
352347 return provider.provide
353348
0 commit comments