Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions verl/models/transformers/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import functools
import logging
import os
import inspect
from dataclasses import dataclass
from typing import Optional

Expand Down Expand Up @@ -373,3 +374,91 @@ def patched_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Apply the patch
Qwen3VLMoeTextSparseMoeBlock.forward = patched_forward
logger.info("Monkey patched Qwen3VLMoeTextSparseMoeBlock.forward to fix router_weights bug")


def patch_qwen3_vl_moe_fast_pos_embed_interpolate():
"""
Monkey patch to fix FSDP param_offload device mismatch on NPU.
When FSDP offloads params, self.pos_embed.weight.device falls back to CPU,
causing indices to be generated on CPU while computations are executed on NPU.
"""
try:
from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
except ImportError:
return

def patched_fast_pos_embed_interpolate(self, grid_thw):
grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]

for t, h, w in zip(grid_ts, grid_hs, grid_ws):
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
Comment on lines +397 to +398
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The steps argument of torch.linspace must be an integer, but h and w are 0-dimensional tensors from iterating over grid_hs and grid_ws. This will raise a TypeError and cause the program to crash. You should use .item() to convert them to Python integers.

Additionally, performing these calculations on the CPU within the loop and then transferring to the target device can be inefficient. Consider performing the computations directly on the grid_thw.device to avoid unnecessary data transfers between CPU and NPU.

Suggested change
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h.item())
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w.item())


h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor

base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side

indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]

weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]

for i in range(4):
idx_list[i].extend(indices[i].tolist())
weight_list[i].extend(weights[i].tolist())

# BUG FIX: Use grid_thw.device instead of self.pos_embed.weight.device
# grid_thw is guaranteed to be on the correct execution device (NPU) during the forward pass.
target_device = grid_thw.device

idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=target_device)
weight_tensor = torch.tensor(
weight_list, dtype=self.pos_embed.weight.dtype, device=target_device
)
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

patch_pos_embeds_permute = []
merge_size = self.config.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds

# Dynamically inject the patch into any Vision Class holding this method
for name, obj in inspect.getmembers(modeling_qwen3_vl_moe, inspect.isclass):
if hasattr(obj, 'fast_pos_embed_interpolate'):
setattr(obj, 'fast_pos_embed_interpolate', patched_fast_pos_embed_interpolate)
logger.info(f"Monkey patched {name}.fast_pos_embed_interpolate to fix FSDP param_offload NPU device bug")


# Execute patches upon module import.
# NOTE: importing this module mutates transformers' Qwen3-VL-MoE classes in
# place; the fast_pos_embed patcher silently no-ops when the model module
# cannot be imported (presumably the sparse-MoE patcher behaves likewise —
# confirm in its definition above).
patch_qwen3_vl_moe_sparse_moe_block_forward()
patch_qwen3_vl_moe_fast_pos_embed_interpolate()