4 changes: 4 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -122,6 +122,10 @@ checkpoint:

   # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
   async_save: False
+
+  # Mbridge config extension.
+  mbridge_config:
+    memory_efficient: False
 
 # optimizer configs
 optim:

Collaborator (review comment on `mbridge_config:`): please don't add these yaml configs.
1 change: 1 addition & 0 deletions verl/trainer/config/config.py
@@ -36,6 +36,7 @@ class CheckpointConfig(BaseConfig):
     save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     async_save: bool = False
+    mbridge_config: Optional[dict[str, Any]] = field(default_factory=dict)
 
 
 @dataclass
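For context, a minimal sketch of how the new field is meant to flow from config into the bridge call; DummyBridge is a hypothetical stand-in for the real mbridge object, and the keyword names mirror the ones used in the megatron_checkpoint_manager.py diff below:

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class CheckpointConfig:
    # Trimmed to the fields relevant here.
    async_save: bool = False
    # Extra kwargs forwarded verbatim to bridge.save_weights().
    mbridge_config: Optional[dict[str, Any]] = field(default_factory=dict)


class DummyBridge:
    # Hypothetical stand-in for the mbridge bridge object.
    def save_weights(self, model, path, distributed_filesystem=False, memory_efficient=False):
        print(f"saving to {path}: dfs={distributed_filesystem}, mem_eff={memory_efficient}")


cfg = CheckpointConfig(mbridge_config={"memory_efficient": False})
DummyBridge().save_weights(object(), "/tmp/hf_ckpt", **cfg.mbridge_config)

Because the field defaults to an empty dict, **checkpoint_config.mbridge_config degrades to a plain save_weights(model, path) call when nothing is configured.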
4 changes: 4 additions & 0 deletions verl/trainer/config/critic/critic.yaml
@@ -98,6 +98,10 @@ checkpoint:
   # Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
   async_save: False
+
+  # Mbridge config extension.
+  mbridge_config:
+    memory_efficient: False
 
 # profile the critic model in `update_critic`
 profiler:
3 changes: 3 additions & 0 deletions verl/trainer/config/sft_trainer_engine.yaml
@@ -52,6 +52,9 @@ checkpoint:
 
   # For more flexibility, you can specify the contents to load from the checkpoint.
   load_contents: ${checkpoint.save_contents}
+  # Mbridge config extension.
+  mbridge_config:
+    memory_efficient: False
 
 trainer:
   default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
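The ${checkpoint.save_contents} reference above is OmegaConf interpolation, which these Hydra-style YAML configs rely on. A small sketch of how the keys resolve (config trimmed for illustration):

from omegaconf import OmegaConf

yaml_src = """
checkpoint:
  save_contents: [model, optimizer, extra]
  load_contents: ${checkpoint.save_contents}
  mbridge_config:
    memory_efficient: False
"""
cfg = OmegaConf.create(yaml_src)
# Interpolation resolves on access: load_contents mirrors save_contents.
assert list(cfg.checkpoint.load_contents) == ["model", "optimizer", "extra"]
print(OmegaConf.to_container(cfg.checkpoint.mbridge_config))  # {'memory_efficient': False}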
4 changes: 2 additions & 2 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -500,7 +500,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
             hf_ckpt_path = get_hf_model_checkpoint_path(local_path)
             if self.vanilla_bridge:
                 self.bridge.save_weights(
-                    self.model, hf_ckpt_path, distributed_filesystem=True, memory_efficient=True
+                    self.model, hf_ckpt_path, **self.checkpoint_config.mbridge_config
                 )
             else:
                 self.bridge.save_hf_weights(self.model, hf_ckpt_path)

Collaborator (review comment on `self.bridge.save_weights(`): recommend modifying like this:

import inspect

extended_args = {}
# Only pass the newer kwargs when the installed mbridge's save_weights accepts them.
if "distributed_filesystem" in inspect.signature(self.bridge.save_weights).parameters:
    extended_args["distributed_filesystem"] = True
    extended_args["memory_efficient"] = True
self.bridge.save_weights(self.model, hf_ckpt_path, **extended_args)

@@ -572,7 +572,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
             hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
             if self.vanilla_bridge:
                 self.bridge.save_weights(
-                    self.model, hf_model_ckpt_path, distributed_filesystem=True, memory_efficient=True
+                    self.model, hf_model_ckpt_path, **self.checkpoint_config.mbridge_config
                 )
             else:
                 self.bridge.save_hf_weights(self.model, hf_model_ckpt_path)
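Combining this change with the reviewer's suggestion above would let the call site forward only the configured kwargs that the installed mbridge actually supports, so older versions don't fail with a TypeError. A hedged sketch, not what the PR currently implements; the helper name is made up:

import inspect


def filtered_bridge_kwargs(save_weights_fn, configured):
    # Keep only configured kwargs that appear in this mbridge version's signature.
    # (A save_weights that takes **kwargs would need a different check.)
    accepted = inspect.signature(save_weights_fn).parameters
    return {k: v for k, v in (configured or {}).items() if k in accepted}


# At the call site, with self.bridge and self.checkpoint_config as in the diff:
# kwargs = filtered_bridge_kwargs(self.bridge.save_weights, self.checkpoint_config.mbridge_config)
# self.bridge.save_weights(self.model, hf_ckpt_path, **kwargs)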