Skip to content

[BUG- Qwen3-VL]RuntimeError: Expected all tensors to be on the same device, but found at least two devices, npu:0 and cpu! #5571

@iceflysnow

Description

@iceflysnow

System Info

verl version: 0.8.0.dev0
Ascend 910B * 4
models: Qwen3-VL-30B-A3B-Instruct

----------Python Info----------
Version : 3.11.13
Compiler : GCC 11.4.0
Build : ('main', 'Nov 2 2025 08:46:33')
Arch : ('64bit', '')
------------Pip Info-----------
Version : 26.0.1
Directory : /usr/local/python3.11.13/lib/python3.11/site-packages/pip
vllm : 0.11.0+empty
sglang : not found.
ray : 2.54.0
torch : 2.7.1
----------verl Info-----------
Version : 0.8.0.dev
Directory : /data/verl-new/verl
Commit Hash : 2703d73
----------Platform Info----------
Platform : Linux-5.15.0-25-generic-aarch64-with-glibc2.35
system : Linux
node : verl-qwen3-vl-30b-1-master-0
release : 5.15.0-25-generic
version : #25-Ubuntu SMP Wed Mar 30 15:57:31 UTC 2022
----------Environment----------
CUDA is not available.
----------System Info----------
Failed to execute nvidia-smi command.
CPU Memory : 2011.60 GB
GPU Count : 0

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

run : /data/verl-new/examples/grpo_trainer/run_qwen3_vl_30b_vllm_fsdp_npu.sh

Error Info:
Traceback (most recent call last):
File "/data/verl-new/verl/trainer/main_ppo.py", line 45, in main
run_ppo(config)
File "/data/verl-new/verl/trainer/main_ppo.py", line 99, in run_ppo
ray.get(runner.run.remote(config))
File "/usr/local/python3.11.13/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/ray/_private/worker.py", line 2981, in get
values, debugger_breakpoint = worker.get_objects(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/ray/_private/worker.py", line 1012, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): [36mray::TaskRunner.run()[39m (pid=670053, ip=10.119.5.132, actor_id=fffe55f82f5a8d4901a57df302000000, repr=<main_ppo.TaskRunner object at 0xffff861e0690>)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/trainer/main_ppo.py", line 352, in run
trainer.fit()
File "/data/verl-new/verl/trainer/ppo/ray_trainer.py", line 1440, in fit
ref_log_prob = self._compute_ref_log_prob(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/trainer/ppo/ray_trainer.py", line 1128, in _compute_ref_log_prob
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/single_controller/ray/base.py", line 55, in call
output = ray.get(output)
^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError(RuntimeError): [36mray::WorkerDict.ref_compute_ref_log_prob()[39m (pid=670905, ip=10.119.5.132, actor_id=95e2ee817aebc70ffb9b96b102000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0xffefa35c5ed0>)
File "/usr/local/python3.11.13/lib/python3.11/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/single_controller/ray/base.py", line 932, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/single_controller/base/decorator.py", line 462, in inner
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/utils/transferqueue_utils.py", line 314, in dummy_inner
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/utils/profiler/profile.py", line 173, in wrapper
return func(self_instance, *args, **kwargs_inner)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/workers/fsdp_workers.py", line 1071, in compute_ref_log_prob
outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/workers/actor/dp_actor.py", line 472, in compute_log_prob
outputs = self._forward_micro_batch(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/workers/actor/dp_actor.py", line 244, in _forward_micro_batch
output = self.actor_module(
^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
return inner()
^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1805, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/models/transformers/qwen3_vl.py", line 317, in forward_with_torch_backend
outputs = self.model(input_ids, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/models/transformers/qwen3_vl.py", line 281, in qwen3_vl_base_forward
input_kwargs = _get_input_embeds(
^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/models/transformers/qwen3_vl.py", line 186, in _get_input_embeds
image_embeds, deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 769, in forward
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/verl-new/verl/models/transformers/qwen3_vl.py", line 52, in patched_method
return original_method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 738, in fast_pos_embed_interpolate
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
return inner()
^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1805, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 190, in forward
return F.embedding(
^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/functional.py", line 2516, in embedding
return handle_torch_function(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/overrides.py", line 1721, in handle_torch_function
result = mode.torch_function(public_api, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/utils/_device.py", line 104, in torch_function
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/python3.11.13/lib/python3.11/site-packages/torch/nn/functional.py", line 2551, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, npu:0 and cpu! (when checking argument for argument indices in method wrapper_NPU__embedding)

Expected behavior

The RuntimeError should not occur.

Bug Fix:
Adding the code below to verl/models/transformers/qwen3_vl.py resolves this issue; the fix has been tested.

#Begin
import inspect


def patch_qwen3_vl_moe_fast_pos_embed_interpolate():
    """Monkey-patch Qwen3-VL-MoE's ``fast_pos_embed_interpolate`` for NPU.

    When FSDP offloads parameters, ``self.pos_embed.weight.device`` reports
    CPU, so the upstream implementation builds its index/weight tensors on
    CPU while the embedding lookup executes on NPU, raising
    ``RuntimeError: Expected all tensors to be on the same device``.
    The patched version derives the target device from ``grid_thw`` instead,
    which is always on the execution device during the forward pass.

    Returns:
        None. Side effect: replaces ``fast_pos_embed_interpolate`` on every
        class in ``modeling_qwen3_vl_moe`` that defines it. Silently a no-op
        when transformers (or this model family) is not importable.
    """
    try:
        from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
    except ImportError:
        # transformers (or this model family) is not installed; nothing to patch.
        return

    def patched_fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate learned 2-D position embeddings per image.

        Mirrors the upstream implementation except that the index and weight
        tensors are created on ``grid_thw.device`` instead of the (possibly
        CPU-offloaded) parameter's device.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            # Four bilinear corner index sets:
            # (floor,floor), (floor,ceil), (ceil,floor), (ceil,ceil).
            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            # Matching bilinear weights for each corner.
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        # BUG FIX: use grid_thw.device, not self.pos_embed.weight.device.
        # grid_thw is on the execution device (NPU) during the forward pass,
        # whereas an FSDP-offloaded parameter may report CPU.
        target_device = grid_thw.device

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=target_device)
        weight_tensor = torch.tensor(
            weight_list, dtype=self.pos_embed.weight.dtype, device=target_device
        )
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            # Repeat per temporal frame, then reorder patches into the
            # spatial-merge-window layout expected by the vision tower.
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    # Dynamically inject the patch into every class in the module that defines
    # this method (covers the vision model class and any aliases).
    for name, obj in inspect.getmembers(modeling_qwen3_vl_moe, inspect.isclass):
        if hasattr(obj, "fast_pos_embed_interpolate"):
            setattr(obj, "fast_pos_embed_interpolate", patched_fast_pos_embed_interpolate)
            logger.info(
                "Monkey patched %s.fast_pos_embed_interpolate to fix FSDP "
                "param_offload NPU device bug",
                name,
            )

# =============================================================================
# Execute patches upon module import
# =============================================================================
patch_qwen3_vl_moe_sparse_moe_block_forward()
patch_qwen3_vl_moe_fast_pos_embed_interpolate()

#End

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions