Commit ad404cb

Authored by albertcity and albertyi
[megatron] fix: add protection for `logits_processor_args.pop("loss_mask")`, which may cause the `forward_fn` of the value net to collapse (verl-project#5204)
### What does this PR do?

Fixes a bug in `gpt_model_forward_no_padding`. The `MegatronEngineWithValueHead` class fails to pass `logits_processor_args` to `forward_fn`, causing a crash when `gpt_model_forward_no_padding` attempts to pop the `loss_mask`.

### Test

> No need.

### Design & Code Changes

> Add an `if logits_processor_args and "loss_mask" in logits_processor_args:` check before calling `logits_processor_args.pop("loss_mask")`.

Co-authored-by: albertyi <albertyi@tencent.com>
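The failure mode and the guard can be sketched in isolation. This is a minimal illustration of the pattern the diff adds, not the actual verl code: `pop_loss_mask` is a hypothetical helper, and the point is only that an unconditional `.pop()` raises `AttributeError` when the caller passes `None` (as the value-head engine effectively does), while the guarded form is safe for `None`, empty, and populated dicts alike.

```python
# Hypothetical standalone sketch of the guard added in this PR.
# `logits_processor_args` may be None when the caller does not supply
# processor kwargs, so an unconditional .pop("loss_mask") would raise
# AttributeError (on None) instead of silently doing nothing.

def pop_loss_mask(logits_processor_args):
    """Safely drop "loss_mask" from the kwargs forwarded to the logits processor."""
    if logits_processor_args and "loss_mask" in logits_processor_args:
        logits_processor_args.pop("loss_mask")
    return logits_processor_args

# None and empty dicts pass through untouched; a populated dict loses only "loss_mask".
assert pop_loss_mask(None) is None
assert pop_loss_mask({}) == {}
assert pop_loss_mask({"loss_mask": 1, "temperature": 0.7}) == {"temperature": 0.7}
```

The `in` check alongside the truthiness test also keeps the pop from raising `KeyError` when the dict exists but carries no `loss_mask` entry.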
1 parent 2320603 commit ad404cb

File tree: 1 file changed (+4, −2 lines)


verl/models/mcore/model_forward.py (4 additions, 2 deletions)
```diff
@@ -198,7 +198,8 @@ def gptmodel_forward_no_padding(
         }
         model_kwargs["labels"] = args["label"].contiguous()
         model_kwargs["loss_mask"] = args["loss_mask"].contiguous()
-        logits_processor_args.pop("loss_mask")
+        if logits_processor_args and 'loss_mask' in logits_processor_args:
+            logits_processor_args.pop("loss_mask")
 
     # For VLM model, need to pass bshd format `input_ids` and `attention_mask`.
     attention_mask = None
@@ -251,7 +252,8 @@ def gptmodel_forward_no_padding(
         }
         model_kwargs["labels"] = args["label"].contiguous()
         model_kwargs["loss_mask"] = args["loss_mask"].contiguous()
-        logits_processor_args.pop("loss_mask")
+        if logits_processor_args and 'loss_mask' in logits_processor_args:
+            logits_processor_args.pop("loss_mask")
 
     output_orig = model(
         input_ids=input_ids_bshd,
```
