
fix: match dtype when assigning position_ids in get_rope_index #109

Open
auto-yun wants to merge 2 commits into ISEEKYAN:main from auto-yun:fix/position-ids-dtype-mismatch

Conversation

@auto-yun
Contributor

Problem

When input_ids are int32, position_ids defaults to torch.long (int64). The assignment:

position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)

moves llm_positions to the correct device but does not cast its dtype, so the index assignment fails with a dtype-mismatch error (PyTorch requires the source and destination dtypes of an index assignment to match).

Fix

Pass dtype=position_ids.dtype in the .to() call to ensure the dtype is always consistent:

position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
    device=position_ids.device, dtype=position_ids.dtype
)
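A minimal standalone sketch of the failure mode and the fix. The mask and position values below are simplified stand-ins for the real attention_mask and llm_positions in get_rope_index; they are not the actual shapes used in the model code.

```python
import torch

# position_ids defaults to torch.long (int64); llm_positions derived from
# int32 input_ids stays int32, so a plain index assignment fails.
position_ids = torch.zeros(4, dtype=torch.long)
mask = torch.tensor([True, False, True, True])   # stand-in for attention_mask[i] == 1
llm_positions = torch.arange(3, dtype=torch.int32)

try:
    position_ids[mask] = llm_positions  # dtype mismatch: int64 dest, int32 source
except RuntimeError as e:
    print("assignment failed:", e)

# The fix: cast device and dtype together in a single .to() call.
position_ids[mask] = llm_positions.to(
    device=position_ids.device, dtype=position_ids.dtype
)
print(position_ids)  # tensor([0, 0, 1, 2])
```

Passing both device and dtype to one .to() call keeps the conversion to a single op, and it is a no-op when the dtypes already agree, so the int64 path is unaffected.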

Files Changed

  • mbridge/models/qwen3_5/rope_utils.py
  • mbridge/models/qwen3_vl/rope_utils.py

Commit 1:
When input_ids are int32, position_ids defaults to long (int64), causing a dtype mismatch during assignment. Fix by explicitly passing dtype=position_ids.dtype in the .to() call for both qwen3_5 and qwen3_vl rope_utils.py.

Commit 2:
Same dtype-mismatch issue as in qwen3_5 and qwen3_vl: when input_ids are int32, position_ids defaults to long (int64), causing assignment errors. Fix by passing dtype=position_ids.dtype in .to() for:
  • mbridge/models/glm4_vl/vl_mixin.py
  • mbridge/models/qwen2_5_vl/rope_utils.py
  • mbridge/models/qwen3_omni_moe/rope_utils.py
