6 changes: 6 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -97,6 +97,9 @@ actor_rollout_ref:
       - extra
       load_contents: ${.save_contents}
       async_save: false
+      mbridge_config:
+        memory_efficient: true
+        distributed_filesystem: true
     use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
@@ -491,6 +494,9 @@ critic:
     - extra
     load_contents: ${.save_contents}
     async_save: false
+    mbridge_config:
+      memory_efficient: true
+      distributed_filesystem: true
   profiler:
     _target_: verl.utils.profiler.ProfilerConfig
     tool: ${oc.select:global_profiler.tool,null}
6 changes: 6 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -84,6 +84,9 @@ actor_rollout_ref:
       - extra
       load_contents: ${.save_contents}
       async_save: false
+      mbridge_config:
+        memory_efficient: true
+        distributed_filesystem: true
     use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
@@ -425,6 +428,9 @@ critic:
     - extra
     load_contents: ${.save_contents}
     async_save: false
+    mbridge_config:
+      memory_efficient: true
+      distributed_filesystem: true
   profiler:
     _target_: verl.utils.profiler.ProfilerConfig
     tool: ${oc.select:global_profiler.tool,null}
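Both generated trainer configs gain the same `mbridge_config` block under their `checkpoint` sections, so the two flags become tunable per run instead of being hard-coded at the call site. A minimal OmegaConf sketch of overriding the generated defaults (standalone illustration only; the `critic.checkpoint` key path mirrors the hunk context above and may not match the full config tree exactly):

    from omegaconf import OmegaConf

    # Generated defaults, reduced to the keys this diff touches.
    base = OmegaConf.create({
        "critic": {"checkpoint": {"mbridge_config": {
            "memory_efficient": True,
            "distributed_filesystem": True,
        }}}
    })
    # A dotlist override, as one might pass on a Hydra-style command line.
    override = OmegaConf.from_dotlist(
        ["critic.checkpoint.mbridge_config.memory_efficient=false"]
    )
    cfg = OmegaConf.merge(base, override)
    print(cfg.critic.checkpoint.mbridge_config.memory_efficient)  # False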
5 changes: 5 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -122,6 +122,11 @@ checkpoint:

   # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
   async_save: False
 
+  # Mbridge config extension.
+  mbridge_config:
+    memory_efficient: true
+    distributed_filesystem: true
+
 # optimizer configs
 optim:

Collaborator (inline review comment on the `mbridge_config:` line): please don't add these yaml configs.
1 change: 1 addition & 0 deletions verl/trainer/config/config.py
@@ -36,6 +36,7 @@ class CheckpointConfig(BaseConfig):
     save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     async_save: bool = False
+    mbridge_config: Optional[dict[str, Any]] = field(default_factory=dict)
 
 
 @dataclass
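The new field defaults to an empty dict, so configs that never mention `mbridge_config` keep their current behavior. A trimmed-down sketch of the dataclass (the real class subclasses verl's `BaseConfig`; only the fields visible in the hunk are reproduced):

    from dataclasses import dataclass, field
    from typing import Any, Optional

    @dataclass
    class CheckpointConfig:
        save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
        load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
        async_save: bool = False
        # Extra kwargs forwarded verbatim to bridge.save_weights(...).
        mbridge_config: Optional[dict[str, Any]] = field(default_factory=dict)

    # Omitting the field yields {}, so `cfg.mbridge_config or {}` is always a dict.
    assert CheckpointConfig().mbridge_config == {}
    cfg = CheckpointConfig(mbridge_config={"memory_efficient": True})
    assert (cfg.mbridge_config or {}) == {"memory_efficient": True}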
5 changes: 5 additions & 0 deletions verl/trainer/config/critic/critic.yaml
@@ -98,6 +98,11 @@ checkpoint:
   # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
   async_save: False
 
+  # Mbridge config extension.
+  mbridge_config:
+    memory_efficient: true
+    distributed_filesystem: true
+
 # profile the critic model in `update_critic`
 profiler:
 
4 changes: 4 additions & 0 deletions verl/trainer/config/sft_trainer_engine.yaml
@@ -52,6 +52,10 @@ checkpoint:

   # For more flexibility, you can specify the contents to load from the checkpoint.
   load_contents: ${checkpoint.save_contents}
+  # Mbridge config extension.
+  mbridge_config:
+    memory_efficient: true
+    distributed_filesystem: true
 
 trainer:
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
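Note that `load_contents: ${checkpoint.save_contents}` is an OmegaConf interpolation: the load list tracks the save list unless explicitly overridden. A self-contained sketch of how it resolves (illustration, not verl's actual config loader):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "checkpoint": {
            "save_contents": ["model", "optimizer", "extra"],
            "load_contents": "${checkpoint.save_contents}",
            "mbridge_config": {"memory_efficient": True, "distributed_filesystem": True},
        }
    })
    # resolve=True expands the interpolation: load follows save automatically.
    print(OmegaConf.to_container(cfg, resolve=True)["checkpoint"]["load_contents"])
    # ['model', 'optimizer', 'extra']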
8 changes: 2 additions & 6 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -499,9 +499,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
             log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
             hf_ckpt_path = get_hf_model_checkpoint_path(local_path)
             if self.vanilla_bridge:
-                self.bridge.save_weights(
-                    self.model, hf_ckpt_path, distributed_filesystem=True, memory_efficient=True
-                )
+                self.bridge.save_weights(self.model, hf_ckpt_path, **(self.checkpoint_config.mbridge_config or {}))
             else:
                 self.bridge.save_hf_weights(self.model, hf_ckpt_path)
 
@@ -571,9 +569,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         if self.bridge is not None:
             hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
             if self.vanilla_bridge:
-                self.bridge.save_weights(
-                    self.model, hf_model_ckpt_path, distributed_filesystem=True, memory_efficient=True
-                )
+                self.bridge.save_weights(self.model, hf_model_ckpt_path, **(self.checkpoint_config.mbridge_config or {}))
             else:
                 self.bridge.save_hf_weights(self.model, hf_model_ckpt_path)
         else:
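The behavioral change in these two hunks: the previously hard-coded kwargs (`distributed_filesystem=True`, `memory_efficient=True`) now come from `checkpoint_config.mbridge_config`, splatted into the call. A standalone sketch of the forwarding pattern (the stub below is hypothetical and only mimics the two kwargs seen in this diff, not mbridge's real `save_weights` signature):

    # Hypothetical stand-in for bridge.save_weights as exercised by this diff.
    def save_weights(model, path, *, distributed_filesystem=False, memory_efficient=False):
        print(f"save {path}: dist_fs={distributed_filesystem}, mem_eff={memory_efficient}")

    mbridge_config = {"memory_efficient": True, "distributed_filesystem": True}
    save_weights("model", "/tmp/hf_ckpt", **(mbridge_config or {}))  # both flags on

    # The `or {}` guard matters: the field can be None if a config system nulls
    # out an empty section, and **None would raise a TypeError.
    save_weights("model", "/tmp/hf_ckpt", **(None or {}))  # falls back to stub defaults

One subtlety worth noting: with an empty `mbridge_config`, `save_weights` now falls back to the library's own defaults rather than the old hard-coded `True` values; the YAML defaults added in this PR set both flags to true, which preserves the previous behavior.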