Commit e19c443

Author: User
fix: explicitly delete forward_data_store to prevent GPU memory leak
On non-last pipeline stages, forward_data_store accumulates GPU tensors from microbatch outputs that are never transferred to rollout_data. These tensors were held in memory until the local variable went out of scope, which in long-running training loops could delay GPU memory reclamation. Explicitly delete forward_data_store after its data has been fully consumed to release references to these tensors as early as possible.
1 parent 0257bd6 commit e19c443
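
The effect described in the commit message can be illustrated without a GPU. The sketch below is a minimal, hypothetical analogue and is not code from this repository: `FakeTensor` and `collect_outputs` are invented names, and weak references stand in for observing whether the stored microbatch outputs are still alive. It assumes CPython, where dropping the last strong reference frees an object immediately.

```python
import gc
import weakref


class FakeTensor:
    """Hypothetical stand-in for a GPU tensor, used only to track object lifetime."""

    def __init__(self, name):
        self.name = name


def collect_outputs(num_microbatches, release_early):
    # Hypothetical analogue of forward_data_store: one entry per microbatch.
    forward_data_store = [FakeTensor(f"mb{i}") for i in range(num_microbatches)]
    # Weak references let us observe liveness without keeping the objects alive.
    probes = list(map(weakref.ref, forward_data_store))

    rollout_data = {"count": len(forward_data_store)}

    if release_early:
        # Drop the only strong reference to the list and, with it, its elements.
        del forward_data_store
        gc.collect()  # not required under CPython refcounting; harmless otherwise

    # Count how many stored outputs are still alive just before returning.
    alive_before_return = sum(p() is not None for p in probes)
    return rollout_data, alive_before_return


if __name__ == "__main__":
    print(collect_outputs(3, release_early=False)[1])
    print(collect_outputs(3, release_early=True)[1])
```

With real tensors the same reference semantics apply, but how much memory is saved depends on how much work runs between the `del` and the point where the local would have gone out of scope anyway.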

1 file changed: +4 -0 lines changed


slime/backends/megatron_utils/model.py

Lines changed: 4 additions & 0 deletions
@@ -293,6 +293,10 @@ def forward_step(
 origin_values[origin_index] = value
 values = origin_values
 rollout_data[f"{store_prefix}{key}"] = values
+
+# Explicitly release forward_data_store to avoid a GPU memory leak
+del forward_data_store
+
 return rollout_data