
[Optimization] Support multimodal runner for image/video feature processing#7485

Open
xiaoxiaohehe001 wants to merge 3 commits into PaddlePaddle:develop from
xiaoxiaohehe001:support_mm_runner

Conversation

@xiaoxiaohehe001
Collaborator

Motivation

Support the image/video feature processing pipeline in the multimodal runner, strengthening the GPU Model Runner's handling of pre-encoded multimodal features.

Changes

fastdeploy/worker/gpu_model_runner.py

  • Added handling for pre-encoded image features from image_feature_urls: already-encoded image embeddings can be passed in directly, skipping the vision encoder computation
  • Added passing and management of image_grid_thws / video_features / video_grid_thws
  • During prefill, compute and set the multimodal attention mask offsets (attn_mask_offsets, decode_states)
  • Call update_attn_mask_offsets before forward to refresh the attention mask
  • Pass attn_mask_offsets into the forward meta for use during model inference
  • Sort prefill requests by idx to keep the processing order consistent
  • Clear intermediate state such as image_features / video_features after forward to prevent memory leaks
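The pre-encoded fast path described above can be sketched as follows. This is a minimal illustration, not the PR's actual code: `gather_image_embeds`, `HIDDEN_SIZE`, and `encode_fn` are hypothetical names, and NumPy arrays stand in for Paddle tensors.

```python
import numpy as np

HIDDEN_SIZE = 8  # assumed model hidden size for this sketch


def gather_image_embeds(multimodal_inputs, encode_fn):
    """Return image embeddings, skipping the vision encoder when the
    request already carries pre-encoded 2-D [tokens, hidden] features."""
    feats = multimodal_inputs.get("image_features")
    if feats is not None and feats[0].ndim == 2:
        # Pre-encoded path: concatenate the ready-made embeddings directly.
        return np.concatenate(feats, axis=0)
    # Raw path: run the vision encoder on the pixel inputs.
    return encode_fn(multimodal_inputs["images"])


def never_called(_):
    raise AssertionError("vision encoder must not run for pre-encoded features")


# Pre-encoded features bypass the encoder entirely.
pre = {"image_features": [np.zeros((3, HIDDEN_SIZE)), np.zeros((2, HIDDEN_SIZE))]}
out = gather_image_embeds(pre, encode_fn=never_called)
print(out.shape)  # (5, 8)
```

The key check is the feature rank: a 2-D tensor is taken to mean "already projected to the hidden size", so the encoder is never invoked for it.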

fastdeploy/worker/input_batch.py

  • Added the image_grid_thws, video_features, video_grid_thws, and video_infinity_scales fields
  • Added initialization of the decode_states, attn_mask_offsets, and attn_mask_offsets_full tensors
  • Extended swap_data, reset, resize, and related operations to maintain the new fields
  • Added handling of generated_modality in swap and reset (previously missing)
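The maintenance rule behind these bullets (every per-slot field must be handled consistently in swap/reset, or slots silently desynchronize) can be sketched with a toy batch. The class and field layout here are illustrative only, not the PR's actual input_batch.py.

```python
import numpy as np


class InputBatchSketch:
    """Minimal sketch of per-slot bookkeeping; field names follow the PR,
    shapes and dtypes are illustrative."""

    def __init__(self, max_num_seqs: int):
        self.decode_states = np.zeros((max_num_seqs, 1), dtype=np.int32)
        self.attn_mask_offsets = np.zeros((max_num_seqs,), dtype=np.int32)
        # Per-slot Python-side metadata (grid shapes, raw feature tensors).
        self.image_grid_thws = [None] * max_num_seqs
        self.video_features = [None] * max_num_seqs

    def swap_data(self, i: int, j: int):
        # Every per-slot field must be swapped together.
        for name in ("image_grid_thws", "video_features"):
            lst = getattr(self, name)
            lst[i], lst[j] = lst[j], lst[i]
        self.decode_states[[i, j]] = self.decode_states[[j, i]]
        self.attn_mask_offsets[[i, j]] = self.attn_mask_offsets[[j, i]]

    def reset(self, i: int):
        # ...and every field must also be cleared on reset.
        self.decode_states[i] = 0
        self.attn_mask_offsets[i] = 0
        self.image_grid_thws[i] = None
        self.video_features[i] = None
```

A field added to `__init__` but forgotten in `swap_data` or `reset` is exactly the class of bug the PR fixes for generated_modality.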

fastdeploy/engine/sched/resource_manager_v1.py

  • Removed the special scheduling restriction on multimodal requests under the Ernie5 architecture (get_enough_request), unifying the scheduling logic

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 19, 2026 13:36
@paddle-bot

paddle-bot bot commented Apr 19, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR enhances the multimodal GPU Model Runner: the runner can receive and forward pre-encoded image/video features directly, attention mask offsets are computed during prefill and passed through to forward, and the scheduler is adjusted by removing the special restriction on Ernie5 multimodal requests to unify the scheduling logic.

Changes:

  • gpu_model_runner.py: added handling of pre-encoded image features and prefill-stage updates of attention mask offsets, passing the related fields into the model forward meta / model_inputs
  • input_batch.py: added the new multimodal fields plus initialization/maintenance of the attn-mask-related tensors
  • resource_manager_v1.py: removed the special scheduling restriction on Ernie5 multimodal requests, unifying the scheduling flow

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.

File Description
fastdeploy/worker/input_batch.py Add and maintain the multimodal fields and the attn mask offsets / decode_states buffers
fastdeploy/worker/gpu_model_runner.py Handle pre-encoded image features, compute and pass attn_mask_offsets, clean up intermediate feature state
fastdeploy/engine/sched/resource_manager_v1.py Remove the special scheduling restriction on Ernie5 multimodal requests

@xiaoxiaohehe001 xiaoxiaohehe001 changed the title [NewFeature] Support multimodal runner for image/video feature processing [Optimization] Support multimodal runner for image/video feature processing Apr 20, 2026

Copilot AI review requested due to automatic review settings April 20, 2026 08:25
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-20 17:05:46

📋 Review summary

PR overview: Supports the image/video pre-encoded feature processing pipeline in the multimodal runner, strengthening the GPU Model Runner's handling of pre-encoded multimodal features such as image_feature_urls, and unifies the scheduling logic for the Ernie5 architecture.
Scope: worker/gpu_model_runner.py, worker/input_batch.py, engine/sched/resource_manager_v1.py
Impact tags: [Engine] [Scheduler]

Issues

Level File Summary
🔴 Bug gpu_model_runner.py:840 In insert_tasks_v1, request.multimodal_inputs may be None; calling .get() on it directly raises AttributeError
🟡 Suggestion gpu_model_runner.py:649 A shape mismatch only logs via logger.error and processing continues; the mismatched tensor is still concatenated, causing downstream errors
🟡 Suggestion gpu_model_runner.py:657 logger.info is called inside the per-request loop and will flood production logs; consider logger.debug

Overall assessment

The implementation is clear overall: the init/swap/reset maintenance of the new fields in input_batch.py is complete, and removing the Ernie5-specific restriction on the scheduling side is a reasonable simplification. The main risk is the missing defense against multimodal_inputs being None in insert_tasks_v1, which will crash at runtime when a multimodal model handles a text-only request; recommend fixing before merge.


if self.enable_mm:
    self.share_inputs["decode_states"][idx, 0] = 0
    inputs = request.multimodal_inputs

🔴 Bug: request.multimodal_inputs may be None, so the next line's inputs.get(...) raises AttributeError

In a multimodal model (self.enable_mm=True), a text-only request's multimodal_inputs defaults to None (see the Request.__init__ definition). Here .get() is called without a None check.

Suggested fix:

if self.enable_mm:
    self.share_inputs["decode_states"][idx, 0] = 0
    inputs = request.multimodal_inputs
    # mm attention_mask
    attn_offset_len = prefill_end_index - prefill_start_index
    if inputs is None or inputs.get("attention_mask_offset", None) is None:
        attention_mask_offset_slice = np.arange(prefill_start_index, prefill_end_index, dtype=np.int32)
    else:
        ...

if isinstance(image_feature[0], paddle.Tensor) and len(image_feature[0].shape) == 2:
    # Enable encode vision_embedding
    for image_feature_tensor in image_feature:
        if image_feature_tensor.shape[1] != self.fd_config.model_config.hidden_size:

🟡 Suggestion: On a shape mismatch only an error is logged and processing is not interrupted; the subsequent .cuda() and paddle.concat still run with the wrong tensor.

When image_feature_tensor.shape[1] != hidden_size, consider a continue to skip that feature, or raising an exception, so a dimension-mismatched tensor is not concatenated and fed into the model, where it would cause hard-to-diagnose inference errors.

image_features_gpu = [vf.cuda() for vf in image_feature]
image_embeds = paddle.concat(image_features_gpu, axis=0)
multi_vision_inputs["image_features"].append(image_embeds)
logger.info("Enable Encode image embedding.")

🟡 Suggestion: logger.info sits inside the per-request loop, so batch inference prints once per request.

Consider logger.debug to avoid flooding production logs and hurting performance and readability. The same applies to logger.info("Disable Encode image embedding.") on line 660.
