diff --git a/verl/models/transformers/qwen3_vl.py b/verl/models/transformers/qwen3_vl.py
index 972848a1a08..55888182c96 100644
--- a/verl/models/transformers/qwen3_vl.py
+++ b/verl/models/transformers/qwen3_vl.py
@@ -15,6 +15,7 @@
 import functools
 import logging
 import os
+import inspect
 from dataclasses import dataclass
 from typing import Optional
 
@@ -373,3 +374,91 @@ def patched_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     # Apply the patch
     Qwen3VLMoeTextSparseMoeBlock.forward = patched_forward
     logger.info("Monkey patched Qwen3VLMoeTextSparseMoeBlock.forward to fix router_weights bug")
+
+
+def patch_qwen3_vl_moe_fast_pos_embed_interpolate():
+    """
+    Monkey patch to fix FSDP param_offload device mismatch on NPU.
+    When FSDP offloads params, self.pos_embed.weight.device falls back to CPU,
+    causing indices to be generated on CPU while computations are executed on NPU.
+    """
+    try:
+        from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    except ImportError:
+        return
+
+    def patched_fast_pos_embed_interpolate(self, grid_thw):
+        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+        idx_list = [[] for _ in range(4)]
+        weight_list = [[] for _ in range(4)]
+
+        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+            h_idxs_floor = h_idxs.int()
+            w_idxs_floor = w_idxs.int()
+            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor
+            dw = w_idxs - w_idxs_floor
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            indices = [
+                (base_h[None].T + w_idxs_floor[None]).flatten(),
+                (base_h[None].T + w_idxs_ceil[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+            ]
+
+            weights = [
+                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+                ((1 - dh)[None].T * dw[None]).flatten(),
+                (dh[None].T * (1 - dw)[None]).flatten(),
+                (dh[None].T * dw[None]).flatten(),
+            ]
+
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        # BUG FIX: Use grid_thw.device instead of self.pos_embed.weight.device.
+        # grid_thw is guaranteed to be on the correct execution device (NPU) during the forward pass.
+        target_device = grid_thw.device
+
+        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=target_device)
+        weight_tensor = torch.tensor(
+            weight_list, dtype=self.pos_embed.weight.dtype, device=target_device
+        )
+        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+        patch_pos_embeds_permute = []
+        merge_size = self.config.spatial_merge_size
+        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+            pos_embed = pos_embed.repeat(t, 1)
+            pos_embed = (
+                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    # Dynamically inject the patch into any vision class holding this method
+    for name, obj in inspect.getmembers(modeling_qwen3_vl_moe, inspect.isclass):
+        if hasattr(obj, "fast_pos_embed_interpolate"):
+            setattr(obj, "fast_pos_embed_interpolate", patched_fast_pos_embed_interpolate)
+            logger.info(f"Monkey patched {name}.fast_pos_embed_interpolate to fix FSDP param_offload NPU device bug")
+
+
+# Execute patches upon module import
+patch_qwen3_vl_moe_sparse_moe_block_forward()
+patch_qwen3_vl_moe_fast_pos_embed_interpolate()
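
For context, a minimal sketch of the device mismatch the second hunk fixes. The embedding size, tensor shapes, and the "npu" device string are illustrative (the same applies to "cuda", and "npu" requires torch_npu); the snippet stands in for what FSDP param_offload does rather than driving FSDP itself:

    import torch
    from torch import nn

    # Stand-in for the vision tower's pos_embed: under FSDP param_offload the
    # module's .weight can report device "cpu" at the moment indices are built.
    pos_embed = nn.Embedding(2304, 1152)  # weights resident on CPU here

    # grid_thw is a forward-pass input, so it sits on the accelerator.
    grid_thw = torch.tensor([[1, 16, 16]], device="npu")  # "cuda" behaves the same

    # Before the patch: indices inherit the offloaded weight's device (CPU) and
    # later collide with accelerator-resident tensors in the embedding lookup.
    idx_before = torch.tensor([0, 1, 2], dtype=torch.long, device=pos_embed.weight.device)

    # After the patch: indices follow the input's device instead.
    idx_after = torch.tensor([0, 1, 2], dtype=torch.long, device=grid_thw.device)
    assert idx_after.device == grid_thw.device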
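
Relatedly, a quick way to confirm the inspect-based injection found a target once the patched module has been imported; which class names print depends on the installed transformers layout, so treat the output as environment-specific:

    import inspect
    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe

    import verl.models.transformers.qwen3_vl  # noqa: F401  (runs the patches on import)

    for name, obj in inspect.getmembers(modeling_qwen3_vl_moe, inspect.isclass):
        if hasattr(obj, "fast_pos_embed_interpolate"):
            # After patching, __qualname__ points into
            # patch_qwen3_vl_moe_fast_pos_embed_interpolate.<locals>
            # rather than the original class, confirming the injection.
            print(name, obj.fast_pos_embed_interpolate.__qualname__)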