Commit 74638b9
[fsdp] fix: add aggressive_empty_cache at end of init_model to prevent vLLM OOM (verl-project#5384)
### What does this PR do?
Adds `aggressive_empty_cache(force_sync=True)` at the end of
`ActorRolloutRefWorker.init_model()` to prevent vLLM from OOMing at
startup when colocated on the same GPUs as FSDP.
Related: verl-project#4229, verl-project#4257 (stale)
After the removal of `ExternalZeroMQDistributedExecutor`, vLLM runs in
separate MP worker processes instead of inside the FSDP worker process.
During `init_model()`, PyTorch's CUDA allocator reserves large transient
blocks for full-model loading before FSDP sharding and
`sync_module_states` broadcasting. After init, these blocks are no
longer needed but remain cached by the allocator (`cudaMalloc`'d, not
`cudaFree`'d). Since vLLM now runs in a separate process with its own
allocator, it cannot reuse these cached blocks — `cudaMemGetInfo`
reports them as "used", and vLLM fails its `gpu_memory_utilization`
check.
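The cross-process visibility problem can be illustrated with a toy, CUDA-free model (all class and attribute names below are made up for illustration, not verl or PyTorch APIs): the "driver" plays the role of `cudaMemGetInfo`, and a caching allocator holds freed blocks that the driver still counts as used, exactly as a second process would observe.

```python
# Toy model of a caching allocator hiding free memory from other processes.
# `driver.free` stands in for cudaMemGetInfo: it cannot see blocks the
# allocator has cached, so a colocated process thinks they are in use.
class Driver:
    def __init__(self, total):
        self.free = total  # memory visible as free to ALL processes


class CachingAllocator:
    def __init__(self, driver):
        self.driver = driver
        self.cached = []  # freed-but-not-returned block sizes (bytes)

    def malloc(self, size):
        for i, blk in enumerate(self.cached):
            if blk >= size:  # reuse a cached block (this process only)
                return self.cached.pop(i)
        self.driver.free -= size  # otherwise take memory from the driver
        return size

    def free(self, size):
        self.cached.append(size)  # cache it; driver still sees it as used

    def empty_cache(self):  # what the fix achieves: cudaFree cached blocks
        self.driver.free += sum(self.cached)
        self.cached.clear()


driver = Driver(total=100)
fsdp = CachingAllocator(driver)
blk = fsdp.malloc(60)   # transient full-model load during init_model()
fsdp.free(blk)          # FSDP no longer needs it after sharding...
print(driver.free)      # 40 — a colocated vLLM process still sees 60 "used"
fsdp.empty_cache()
print(driver.free)      # 100 — now visible to vLLM's allocator
```

In the real system, `empty_cache()` corresponds to returning the allocator's cached segments to the CUDA driver, which is the only layer whose accounting `cudaMemGetInfo` (and hence vLLM's `gpu_memory_utilization` check) can observe.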
Previous attempts to fix this (verl-project#4257) went stale. This approach is simpler: a single line, no guards, and a no-op when there is nothing to free.
### Checklist Before Starting
- [x] Search for similar PRs:
[aggressive_empty_cache](https://github.com/verl-project/verl/pulls?q=aggressive_empty_cache),
[OOM fsdp vllm
init](https://github.com/verl-project/verl/pulls?q=OOM+fsdp+vllm+init)
- [x] Format the PR title as `[{modules}] {type}: {description}`
### Test
This cannot be tested in CI because the OOM is a cross-process CUDA
memory visibility issue that requires colocated FSDP + vLLM on the same
physical GPU to reproduce.
- The call site is exercised by the existing
`tests/workers/test_fsdp_workers.py`
- Validated experimentally with a colocated FSDP + vLLM training run (8B
VLM, 8x H200, hybrid mode)
- The fix is a no-op when there is no cached memory to free, so it is
safe in all configurations
### API and Usage Example
No API changes. The fix is automatic.
### Design & Code Changes
One line added at the end of `ActorRolloutRefWorker.init_model()` in
`verl/workers/fsdp_workers.py`:
```python
# Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo
aggressive_empty_cache(force_sync=True)
```
This pattern already exists in the codebase:
- `megatron_workers.py:677` — `empty_cache()` at end of `init_model`
- `fsdp_workers.py:742` — `aggressive_empty_cache` during
`rollout_mode()` context switch
- `engine_workers.py:671` — `aggressive_empty_cache` during rollout mode
The FSDP worker was the only one missing it at init time.
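For intuition, a minimal sketch of what such a helper plausibly does (the actual implementation lives in verl; the function name below and its exact behavior are assumptions for illustration). It drops dead Python references, optionally synchronizes so in-flight kernels finish, then returns cached allocator blocks to the driver:

```python
import gc


def aggressive_empty_cache_sketch(force_sync: bool = True) -> None:
    """Illustrative sketch only -- not verl's actual implementation."""
    gc.collect()  # drop dead tensors so their blocks become releasable
    try:
        import torch
    except ImportError:
        return  # nothing to do without PyTorch
    if torch.cuda.is_available():
        if force_sync:
            torch.cuda.synchronize()  # wait for pending kernels to finish
        torch.cuda.empty_cache()      # cudaFree cached allocator blocks
        torch.cuda.ipc_collect()      # also reclaim IPC-shared memory


aggressive_empty_cache_sketch(force_sync=True)  # safe no-op if nothing cached
```

The `force_sync` step matters for the init-time call site: without a synchronize, blocks still referenced by queued kernels cannot be freed, so the cache release could silently skip exactly the transient init allocations the fix targets.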
### Checklist Before Submitting
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
all hooks pass.
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs). —
N/A, no user-facing changes.
- [x] Add unit or end-to-end test(s) — not feasible: requires
multi-process colocated GPU setup to reproduce; the fix is exercised by
existing `test_fsdp_workers.py`.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1).
- [x] Not related to `recipe` submodule.
File tree: 3 files changed in `verl/workers` (+8, −1 lines).