Support for gradient_checkpointing in Vision Encoder to Reduce Memory Usage with High-Resolution Images on verl #114
Description
Thank you for the great work on verl. I'd like to raise a question regarding memory optimization for the vision encoder during training.
Environment & Model:
I'm training Qwen3.5-35B-A3B (or a similar vision-language model) with verl, and I hit an OOM (out-of-memory) error when using high-resolution images (~5000 tokens per image). The error occurs at the following line:
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/single_controller/ray/base.py", line 932, in func
(TaskRunner pid=157926) return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/single_controller/base/decorator.py", line 427, in inner
(TaskRunner pid=157926) return func(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/utils/profiler/profile.py", line 173, in wrapper
(TaskRunner pid=157926) return func(self_instance, *args, **kwargs_inner)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine_workers.py", line 67, in wrapper
(TaskRunner pid=157926) return func(self, data, *args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine_workers.py", line 633, in update_actor
(TaskRunner pid=157926) output = self.actor.train_mini_batch(data=data)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/single_controller/base/decorator.py", line 427, in inner
(TaskRunner pid=157926) return func(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine_workers.py", line 293, in train_mini_batch
(TaskRunner pid=157926) actor_output = self.train_batch(mini_batch_td)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/single_controller/base/decorator.py", line 427, in inner
(TaskRunner pid=157926) return func(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine_workers.py", line 341, in train_batch
(TaskRunner pid=157926) output = self.engine.train_batch(data, loss_function=self.loss_fn)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine/base.py", line 126, in train_batch
(TaskRunner pid=157926) outputs = self.forward_backward_batch(data, loss_function, forward_only=False)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine/megatron/transformer_impl.py", line 621, in forward_backward_batch
(TaskRunner pid=157926) losses_reduced = forward_backward_func(
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 2202, in forward_backward_pipelining_without_interleaving
(TaskRunner pid=157926) output_tensor, num_tokens = forward_step(
(TaskRunner pid=157926) ^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 423, in forward_step
(TaskRunner pid=157926) output_tensor, loss_func = forward_step_func(data_iterator, model)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/workers/engine/megatron/transformer_impl.py", line 862, in forward_step
(TaskRunner pid=157926) output = forward_fn(
(TaskRunner pid=157926) ^^^^^^^^^^^
(TaskRunner pid=157926) File "/sls-log/ray/session_2026-04-02_01-43-26_754436_1157/runtime_resources/working_dir_files/_ray_pkg_0b1fea510e48d28d/verl/models/mcore/model_forward.py", line 333, in gptmodel_forward_no_padding
(TaskRunner pid=157926) output_orig = model(
(TaskRunner pid=157926) ^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(TaskRunner pid=157926) return self._call_impl(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(TaskRunner pid=157926) return forward_call(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/megatron/core/distributed/data_parallel_base.py", line 22, in forward
(TaskRunner pid=157926) return self.module(*inputs, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(TaskRunner pid=157926) return self._call_impl(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(TaskRunner pid=157926) return forward_call(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 489, in forward
(TaskRunner pid=157926) outputs = self.module(*inputs, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(TaskRunner pid=157926) return self._call_impl(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(TaskRunner pid=157926) return forward_call(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/mbridge/models/qwen3_5/model.py", line 272, in forward
(TaskRunner pid=157926) vision_embeds = self.vision_model(
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(TaskRunner pid=157926) return self._call_impl(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(TaskRunner pid=157926) return forward_call(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 917, in wrapper
(TaskRunner pid=157926) output = func(self, *args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/utils/output_capturing.py", line 253, in wrapper
(TaskRunner pid=157926) outputs = func(self, *args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py", line 1356, in forward
(TaskRunner pid=157926) hidden_states = blk(
(TaskRunner pid=157926) ^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/modeling_layers.py", line 93, in __call__
(TaskRunner pid=157926) return super().__call__(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(TaskRunner pid=157926) return self._call_impl(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(TaskRunner pid=157926) return forward_call(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py", line 1170, in forward
(TaskRunner pid=157926) hidden_states = hidden_states + self.attn(
(TaskRunner pid=157926) ^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(TaskRunner pid=157926) return self._call_impl(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(TaskRunner pid=157926) return forward_call(*args, **kwargs)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py", line 1098, in forward
(TaskRunner pid=157926) query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py", line 1064, in apply_rotary_pos_emb_vision
(TaskRunner pid=157926) q_embed = (q * cos) + (rotate_half(q) * sin)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/local/lib/python3.12/dist-packages/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py", line 635, in rotate_half
(TaskRunner pid=157926) return torch.cat((-x2, x1), dim=-1)
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 88.00 MiB. GPU 0 has a total capacity of 79.18 GiB of which 66.50 MiB is free. Including non-PyTorch memory, this process has 73.29 GiB memory in use. Process 53252 has 5.79 GiB memory in use. Of the allocated memory 65.88 GiB is allocated by PyTorch, and 183.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
(TaskRunner pid=157926) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.actor_rollout_ref_update_actor() (pid=51246, ip=33.3.169.18, actor_id=0522e0c1f67ed72a5713373e0e000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f2b8801f620>)
(TaskRunner pid=157926) File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
(TaskRunner pid=157926) return self.__get_result()
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=157926) File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
(TaskRunner pid=157926) raise self._exception
(TaskRunner pid=157926) ^^^^^^^^^^^^^^^^^^^^^
Problem:
It appears that gradient_checkpointing is not currently enabled for the vision encoder (ViT) in verl's training pipeline. By contrast, ms-swift provides a dedicated vit_gradient_checkpointing option (see the ms-swift documentation) that enables gradient checkpointing specifically for the vision encoder, which significantly reduces memory consumption.
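For context, activation checkpointing trades compute for memory: intermediate activations are discarded during the forward pass and recomputed during backward, which is exactly what helps when a ViT processes thousands of image tokens. A minimal, self-contained sketch of the mechanism using torch.utils.checkpoint (toy linear blocks stand in for vision transformer blocks; none of this is verl's actual API):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a stack of ViT blocks; names here are illustrative only.
class BlockStack(nn.Module):
    def __init__(self, dim=16, depth=4, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Discard this block's activations now; recompute in backward.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

torch.manual_seed(0)
plain = BlockStack(use_checkpoint=False)
ckpt = BlockStack(use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())

x = torch.randn(8, 16)
plain.train(); ckpt.train()
plain(x).sum().backward()
ckpt(x).sum().backward()

# Checkpointing changes memory use, not the math: gradients are identical.
same = all(torch.allclose(p.grad, q.grad)
           for p, q in zip(plain.parameters(), ckpt.parameters()))
print(same)  # True
```

With checkpointing enabled, peak activation memory scales with one block rather than the full depth, at the cost of one extra forward recomputation per block in backward.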
Question / Feature Request:
Would it be possible to add support for enabling gradient_checkpointing on the vision encoder in verl? For example, calling vision_model.enable_gradient_checkpointing(), or exposing a config option such as vit_gradient_checkpointing=True, similar to what ms-swift offers.
This would be very helpful for training with larger image resolutions without running into memory issues. I'd appreciate any guidance on whether this is feasible or if there's an existing workaround.
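In case it helps others hitting the same OOM, one possible stopgap (a sketch under assumptions: the helper and the way one reaches the vision blocks are hypothetical, not verl's confirmed API) is to wrap each vision block's forward in torch.utils.checkpoint after the model is built:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def checkpoint_module(module: nn.Module) -> nn.Module:
    """Wrap a module's forward in activation checkpointing.

    Hypothetical workaround helper: in practice one would apply this to each
    block of the vision tower after model construction, e.g. something like
    `for blk in vision_model.blocks: checkpoint_module(blk)` -- the attribute
    path depends on the actual model wrapper and is an assumption here.
    """
    orig_forward = module.forward

    def ckpt_forward(*args, **kwargs):
        # use_reentrant=False supports kwargs and inputs without requires_grad.
        return checkpoint(orig_forward, *args, use_reentrant=False, **kwargs)

    module.forward = ckpt_forward
    return module

# Demonstration on a toy block to show the wrapper still trains correctly.
blk = checkpoint_module(nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)
blk(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])
```

This kind of patching is fragile (it can interact badly with pipeline parallelism or module re-wrapping), so a first-class vit_gradient_checkpointing-style config option would still be the preferable solution.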
Thank you in advance for your time and support!