[Cherry-Pick][Optimization] Support multimodal runner for image/video…#7576

Open
xiaoxiaohehe001 wants to merge 1 commit into PaddlePaddle:release/2.6 from xiaoxiaohehe001:26_fd_runner

Conversation

@xiaoxiaohehe001
Collaborator

Motivation

Support the image/video feature processing flow in the multimodal runner, and strengthen the GPU Model Runner's ability to handle pre-encoded multimodal features.

Changes

Cherry-Pick from #7485

fastdeploy/worker/gpu_model_runner.py

  • Add handling for pre-encoded image features referenced by image_feature_urls, so already-encoded image embeddings can be passed in directly and the vision encoder computation skipped (a sketch of this path follows this list)
  • Add propagation and management of image_grid_thws, video_features / video_grid_thws
  • Compute and set the multimodal attention mask offsets (attn_mask_offsets, decode_states) during the prefill stage
  • Call update_attn_mask_offsets to refresh the attention mask before forward
  • Pass attn_mask_offsets into the forward meta for use during model inference
  • Sort prefill requests by idx to keep the processing order consistent
  • Clear intermediate state such as image_features / video_features after forward to prevent memory leaks
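
A minimal sketch of the pre-encoded image path described above. It is illustrative only, not the actual gpu_model_runner.py code: the function name gather_image_embeddings, the device/encode_images parameters, and the fallback call are assumptions, while image_feature_urls, image_features, request.image_start/image_end, and the hidden-size check come from this PR.

import paddle

def gather_image_embeddings(request, inputs, hidden_size, device, encode_images):
    """Return image embeddings for one request, skipping the vision encoder
    when pre-encoded features are already present in the inputs."""
    # Pre-encoded path: features referenced by image_feature_urls were
    # downloaded earlier and stored as tensors under inputs["image_features"].
    if inputs.get("image_feature_urls"):
        features = inputs["image_features"][request.image_start : request.image_end]
        for feat in features:
            # Validate the hidden size up front instead of letting the later
            # concat / model forward fail with an opaque shape error.
            if feat.shape[1] != hidden_size:
                raise ValueError(f"expected shape [*, {hidden_size}], got {list(feat.shape)}")
        # Concatenate the per-image tensors and move them to the runner's device.
        return paddle.concat(features, axis=0).to(device)
    # Fallback path: run the vision encoder on the raw multimodal inputs.
    return encode_images(inputs)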

fastdeploy/worker/input_batch.py

  • Add the image_grid_thws, video_features, video_grid_thws, and video_infinity_scales fields
  • Add tensor initialization for decode_states, attn_mask_offsets, and attn_mask_offsets_full
  • Maintain the new fields in operations such as swap_data, reset, and resize (see the sketch after this list)
  • Handle generated_modality in swap and reset (previously missed)
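
For illustration, a minimal sketch of how batch-aligned multimodal state might be swapped and reset. The swap_data signature matches its usage in this PR, but the function bodies, the swap_multimodal_state/reset_multimodal_state helpers, and the exact set of buffers are assumptions rather than the actual input_batch.py code.

def swap_data(buffer, i1, i2):
    """Swap batch slots i1 and i2 in a batch-major paddle.Tensor buffer."""
    tmp = buffer[i1].clone()
    buffer[i1] = buffer[i2]
    buffer[i2] = tmp

def swap_multimodal_state(batch, i1, i2):
    # Batch-aligned multimodal buffers must be swapped together, otherwise a
    # preempted request can end up paired with another request's state.
    for name in ("decode_states", "attn_mask_offsets_full"):
        swap_data(getattr(batch, name), i1, i2)
    # Python-side per-request containers are swapped the same way.
    batch.video_features[i1], batch.video_features[i2] = (
        batch.video_features[i2],
        batch.video_features[i1],
    )

def reset_multimodal_state(batch):
    # On reset, sentinel-fill the tensors and drop any cached features.
    batch.attn_mask_offsets_full.fill_(-1)
    batch.decode_states.fill_(0)
    batch.video_features = [None] * len(batch.video_features)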

fastdeploy/engine/sched/resource_manager_v1.py

  • Remove the special scheduling restriction on multimodal requests under the Ernie5 architecture (get_enough_request) and unify the scheduling logic

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 23, 2026 06:18
@paddle-bot

paddle-bot Bot commented Apr 23, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR is cherry-picked from upstream. It adds support in the multimodal runner for pre-encoded image/video features (embeddings downloaded from a BOS URL and injected directly), and fills in the computation/propagation of multimodal attention mask offsets on the GPUModelRunner prefill path, improving how multimodal inputs are handled during VL model inference.

Changes:

  • GPUModelRunner now handles image_feature_urls / image_features + image_grid_thws and calls update_attn_mask_offsets before forward to produce attn_mask_offsets (a sketch of this flow follows this list).
  • InputBatch adds multimodal fields (image/video grids and features, decode_states, attn_mask_offsets, etc.) and maintains them in swap/reset.
  • ResourceManagerV1 removes the special scheduling restriction for multimodal requests under Ernie5 and unifies the scheduling logic.
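
A sketch of the attn_mask_offsets sequencing referenced above (update before forward, copy into the shared buffer, expose via the forward meta). It is an illustration, not the runner's literal code: prepare_forward_meta and the SimpleNamespace stand-in for the forward meta are hypothetical, and the argument list of update_attn_mask_offsets is partly inferred from the hunks quoted in the review below.

from types import SimpleNamespace

def prepare_forward_meta(share_inputs, update_attn_mask_offsets):
    """Recompute attention mask offsets before forward and expose them to the model."""
    # 1) The custom op derives two offsets per token, so its output length is
    #    ids_remove_padding.shape[0] * 2 (per the review comments below).
    attn_mask_offsets = update_attn_mask_offsets(
        share_inputs["ids_remove_padding"],
        share_inputs["is_block_step"],
        share_inputs["decode_states"],
    )
    # 2) Copy into the preallocated shared buffer; the shapes must match,
    #    which is why the buffer capacity needs the *2 factor.
    share_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
    # 3) Hand the buffer to the model through the forward meta.
    return SimpleNamespace(attn_mask_offsets=share_inputs["attn_mask_offsets"])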

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.

File Description
fastdeploy/worker/input_batch.py: Adds multimodal fields and attn-mask-related buffers, and extends the swap/reset logic to cover the new state
fastdeploy/worker/gpu_model_runner.py: Supports injection of pre-encoded features, updates attn_mask_offsets in the prefill stage, and passes them into forward_meta
fastdeploy/engine/sched/resource_manager_v1.py: Removes the special scheduling gating for Ernie5 multimodal requests and unifies the scheduling path

Comment on lines +701 to +702
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)

Copilot AI Apr 23, 2026


Filling attn_mask_offsets with -1 in reset_share_inputs is fine, but the preallocated shape of attn_mask_offsets must stay consistent with the output of update_attn_mask_offsets (token_num * 2). With the current, smaller initialization shape, the copy in prepare_inputs will fail. Suggest maintaining the *2 shape here as well.

Suggested change
fill_paddle_tensor(self, "attn_mask_offsets", -1)
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
attn_mask_token_num = max_num_seqs * self.model_config.max_model_len
attn_mask_offsets_shape = [attn_mask_token_num * 2]
attn_mask_offsets = getattr(self, "attn_mask_offsets", None)
if attn_mask_offsets is None or list(attn_mask_offsets.shape) != attn_mask_offsets_shape:
    attn_mask_offsets_dtype = attn_mask_offsets.dtype if attn_mask_offsets is not None else "int32"
    self.attn_mask_offsets = paddle.full(
        shape=attn_mask_offsets_shape,
        fill_value=-1,
        dtype=attn_mask_offsets_dtype,
    )
else:
    fill_paddle_tensor(self, "attn_mask_offsets", -1)
attn_mask_offsets_full = getattr(self, "attn_mask_offsets_full", None)
if attn_mask_offsets_full is None or list(attn_mask_offsets_full.shape) != attn_mask_offsets_shape:
    attn_mask_offsets_full_dtype = (
        attn_mask_offsets_full.dtype if attn_mask_offsets_full is not None else "int32"
    )
    self.attn_mask_offsets_full = paddle.full(
        shape=attn_mask_offsets_shape,
        fill_value=-1,
        dtype=attn_mask_offsets_full_dtype,
    )
else:
    fill_paddle_tensor(self, "attn_mask_offsets_full", -1)

Comment on lines +937 to 939
swap_data(self.attn_mask_offsets, i1, i2)
swap_data(self.attn_mask_offsets_full, i1, i2)
swap_data(self.attn_mask_offsets_decoder, i1, i2)

Copilot AI Apr 23, 2026


Adding swap_data(self.attn_mask_offsets, i1, i2) here is likely incorrect: attn_mask_offsets is a token-flattened 1-D buffer (its length is aligned with the ids_remove_padding token_num, and update_attn_mask_offsets outputs token_num * 2), so swapping i1/i2 along the batch dimension only exchanges single elements and cannot stay consistent with cu_seqlens_q/query_start, which can easily corrupt the attention mask offsets. Suggest not swapping this 1-D buffer during reorder/swap; instead, recompute it after each pre_process and copy_ the whole segment, or change its storage to a batch-wise 2-D layout before implementing a correct swap.

Suggested change
swap_data(self.attn_mask_offsets, i1, i2)
swap_data(self.attn_mask_offsets_full, i1, i2)
swap_data(self.attn_mask_offsets_decoder, i1, i2)
# Attention mask offset buffers may be token-flattened derived state
# rather than batch-aligned storage. Swapping a single element by
# batch index can break consistency with the flattened token layout.
# Keep them untouched here and let the later preprocessing stage
# rebuild them from the current batch layout.

Comment on lines +633 to +636
multi_vision_inputs["image_grid_thws"].extend(
inputs["image_grid_thws"][request.image_start : request.image_end]
)
image_feature = inputs["image_features"][request.image_start : request.image_end]

Copilot AI Apr 23, 2026


This branch only checks that image_feature_urls is non-empty, yet then directly accesses inputs["image_grid_thws"] and inputs["image_features"]. ResourceManagerV1._download_features only fills image_features (image_grid_thws is not guaranteed to exist), so this branch can raise a KeyError or hit a slicing error from mismatched lengths. Suggest explicitly validating that both fields exist and that their lengths align with image_feature_urls; when they are missing, set a clear error_message/error_code on the request or raise an exception directly.

Suggested change
multi_vision_inputs["image_grid_thws"].extend(
    inputs["image_grid_thws"][request.image_start : request.image_end]
)
image_feature = inputs["image_features"][request.image_start : request.image_end]
image_feature_urls = inputs["image_feature_urls"]
image_grid_thws = inputs.get("image_grid_thws")
image_features = inputs.get("image_features")
image_start = request.image_start
image_end = request.image_end
if image_grid_thws is None or image_features is None:
    raise ValueError(
        "Missing multimodal input fields for image features: "
        f"request_idx={request.idx}, "
        f"has_image_feature_urls={image_feature_urls is not None}, "
        f"has_image_features={image_features is not None}, "
        f"has_image_grid_thws={image_grid_thws is not None}"
    )
if not (
    len(image_feature_urls) == len(image_features) == len(image_grid_thws)
):
    raise ValueError(
        "Mismatched multimodal input lengths: "
        f"request_idx={request.idx}, "
        f"image_feature_urls={len(image_feature_urls)}, "
        f"image_features={len(image_features)}, "
        f"image_grid_thws={len(image_grid_thws)}"
    )
if not (0 <= image_start <= image_end <= len(image_feature_urls)):
    raise ValueError(
        "Invalid image slice range: "
        f"request_idx={request.idx}, "
        f"image_start={image_start}, "
        f"image_end={image_end}, "
        f"total_images={len(image_feature_urls)}"
    )
multi_vision_inputs["image_grid_thws"].extend(
    image_grid_thws[image_start:image_end]
)
image_feature = image_features[image_start:image_end]

Comment on lines +641 to +646
for image_feature_tensor in image_feature:
if image_feature_tensor.shape[1] != self.fd_config.model_config.hidden_size:
logger.error(
f"Shape mismatch: expected shape={self.fd_config.model_config.hidden_size}, \
but got {image_feature_tensor.shape}"
)

Copilot AI Apr 23, 2026


When the hidden_size of image_feature_tensor does not match, this only calls logger.error and keeps going; the subsequent concat / model forward will very likely crash or produce wrong results because of the shape mismatch. Suggest failing fast on mismatch (raise an exception, or set request.error_message/error_code and skip the request), and include both the expected shape (e.g. [*, hidden_size]) and the actual shape in the error message.

Suggested change
for image_feature_tensor in image_feature:
    if image_feature_tensor.shape[1] != self.fd_config.model_config.hidden_size:
        logger.error(
            f"Shape mismatch: expected shape={self.fd_config.model_config.hidden_size}, \
            but got {image_feature_tensor.shape}"
        )
expected_hidden_size = self.fd_config.model_config.hidden_size
for image_feature_tensor in image_feature:
    if image_feature_tensor.shape[1] != expected_hidden_size:
        error_message = (
            f"Image feature hidden size mismatch for request idx={request.idx}: "
            f"expected shape [*, {expected_hidden_size}], "
            f"but got {list(image_feature_tensor.shape)}"
        )
        logger.error(error_message)
        raise ValueError(error_message)

Comment on lines +364 to +365
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],

Copilot AI Apr 23, 2026


The preallocated length of attn_mask_offsets here is only max_num_seqs * max_model_len, but the GPU operator update_attn_mask_offsets outputs ids_remove_padding.shape[0] * 2 (two offsets per token). The current size will lead to a shape mismatch or potential out-of-bounds access during the later copy_. Suggest preallocating according to the maximum length of ids_remove_padding multiplied by 2 (e.g. max_num_seqs * max_chunk_tokens * 2), and keeping it consistent on reset.

Suggested change
self.attn_mask_offsets = paddle.full(
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
attn_mask_token_capacity = self.scheduler_config.max_num_seqs * self.model_config.max_model_len
self.attn_mask_offsets = paddle.full(
shape=[attn_mask_token_capacity * 2],

self.share_inputs["is_block_step"],
self.share_inputs["decode_states"],
)
self.share_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)

Copilot AI Apr 23, 2026


The attn_mask_offsets returned by update_attn_mask_offsets on the GPU side has length ids_remove_padding.shape[0] * 2 (see the operator implementation), and copying it directly into share_inputs["attn_mask_offsets"] requires the destination buffer shape to match exactly. Please confirm that the preallocated length of attn_mask_offsets in InputBatch/ProposerInputBatch is max_token_capacity * 2, otherwise this will fail at runtime with a shape mismatch.

Suggested change
self.share_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
attn_mask_offsets_buffer = self.share_inputs["attn_mask_offsets"]
attn_mask_offsets_len = attn_mask_offsets.shape[0]
if attn_mask_offsets_buffer.numel() < attn_mask_offsets_len:
    raise RuntimeError(
        "attn_mask_offsets buffer capacity is insufficient: "
        f"required={attn_mask_offsets_len}, "
        f"capacity={attn_mask_offsets_buffer.numel()}. "
        "Please ensure the preallocated attn_mask_offsets buffer "
        "has capacity for max_token_capacity * 2."
    )
attn_mask_offsets_buffer[:attn_mask_offsets_len].copy_(attn_mask_offsets, False)

Comment on lines +647 to +648
image_features_gpu = [vf.cuda() for vf in image_feature]
image_embeds = paddle.concat(image_features_gpu, axis=0)

Copilot AI Apr 23, 2026


Using vf.cuda() hard-codes the CUDA device, and migrating tensors one by one may add extra overhead; the other paths in this runner generally use .to(self.device) to keep the device consistent. Suggest switching to image_feature_tensor.to(self.device) (or placing tensors on the target device uniformly in the to_tensor stage) and minimizing per-tensor copies.

Suggested change
image_features_gpu = [vf.cuda() for vf in image_feature]
image_embeds = paddle.concat(image_features_gpu, axis=0)
image_embeds = paddle.concat(image_feature, axis=0).to(self.device)

Comment on lines +183 to +185
# Note(Zhengshifeng) init video cache for VL model
self.video_cache = {}


Copilot AI Apr 23, 2026


self.video_cache is only initialized in this file and is never read or written by any other logic (a search hits only this one spot). If it is not a field about to be used within the scope of this PR, suggest removing it to avoid confusion; if it will be used later, at least document its lifecycle/access path in the comment, or add the usage in this PR.

Suggested change
# Note(Zhengshifeng) init video cache for VL model
self.video_cache = {}

Comment on lines +628 to +633
if (
inputs is not None
and inputs.get("image_feature_urls", None) is not None
and len(inputs["image_feature_urls"]) > 0
):
multi_vision_inputs["image_grid_thws"].extend(

Copilot AI Apr 23, 2026


The new pre-encoded feature path (image_feature_urls/image_features + image_grid_thws) and the attention_mask_offset filling/updating in the prefill stage are critical inference logic, but the current tests/worker/test_gpu_model_runner.py does not cover this branch. Suggest adding unit tests: construct requests with image_feature_urls (with/without image_grid_thws, with shape mismatches, etc.) to verify the output types and error handling of process_mm_features, and also cover the basic shape constraints of update_attn_mask_offsets + copy.

Copilot generated this review using guidance from repository custom instructions.

@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-23 14:25:09

📋 Review Summary

PR overview: Adds multimodal (image/video) feature handling to the GPU Model Runner, supporting direct pass-through of pre-encoded image embeddings and computation of attention mask offsets, and removes the special scheduling restriction on multimodal requests under the Ernie5 architecture.
Scope of changes: worker/gpu_model_runner.py, worker/input_batch.py, engine/sched/resource_manager_v1.py
Impact tags: [Engine] [Models] [Scheduler]


📝 PR Compliance Check

The PR title carries the [Cherry-Pick][Optimization] tags and the format meets the convention; however, a Cherry-Pick PR should append the original PR ID at the end of the title.

Suggested title (can be copied directly):

  • [Cherry-Pick][Optimization] Support multimodal runner for image/video features (#7485)

Issues

Level File Summary
🔴 Bug gpu_model_runner.py:834 inputs may be None; calling .get() on it directly will raise AttributeError
🟡 Suggestion gpu_model_runner.py:184 self.video_cache = {} is never used after initialization
🟡 Suggestion input_batch.py:371 the non-speculative path lacks attn_mask_offsets_decoder initialization, yet swap_states accesses this field
🟡 Suggestion input_batch.py:482 the swap_data block misses the swap of attn_mask_offsets, inconsistent with the swap_states logic

Overall Assessment

The overall design of the multimodal feature handling and the attention mask offsets computation is clear, and the field additions and reset/swap coverage are fairly systematic. However, there is a potential None-dereference bug (multimodal_inputs in insert_tasks_v1 calls .get() without a None check) and consistency issues from incomplete field initialization/swap; suggest fixing these before merging.

length = prefill_end_index - prefill_start_index
if self.enable_mm:
self.share_inputs["decode_states"][idx, 0] = 0
inputs = request.multimodal_inputs


🔴 Bug: inputs may be None, but inputs.get(...) is called here without a None check

When request.multimodal_inputs returns None (a non-multimodal request entering the prefill branch), inputs.get("attention_mask_offset", None) raises AttributeError: 'NoneType' object has no attribute 'get'.

Suggested fix:

if self.enable_mm:
    self.share_inputs["decode_states"][idx, 0] = 0
    inputs = request.multimodal_inputs
    if inputs is not None:
        attn_offset_len = prefill_end_index - prefill_start_index
        if inputs.get("attention_mask_offset", None) is None:
            attention_mask_offset_slice = np.arange(...)
        ...

self.encoder_cache = None

# Note(Zhengshifeng) init video cache for VL model
self.video_cache = {}


🟡 Suggestion: self.video_cache = {} has no reads or writes anywhere in this file after initialization

video_cache is initialized but never used across all the changes in this PR (video_features is passed directly via self.share_inputs). If the field is intended for a follow-up feature, add a comment explaining its purpose; otherwise remove it to avoid confusion.

)
self.attn_mask_offsets_full = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
)


🟡 Suggestion: the non-speculative path lacks attn_mask_offsets_decoder initialization

swap_states (line 939) accesses self.attn_mask_offsets_decoder under the self.enable_mm condition, but the newly added multimodal initialization branch in init_share_inputs (non-speculative) only initializes attn_mask_offsets_full and attn_mask_offsets; it is missing:

self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32")

If swap_states is also needed in the non-speculative case, this will raise an AttributeError.

)
swap_data(self.share_inputs["rope_emb"], i1, i2)
swap_data(self.decode_states, i1, i2)
swap_data(self.attn_mask_offsets_full, i1, i2)


🟡 Suggestion: swap_data adds swaps for decode_states and attn_mask_offsets_full, but the swap for attn_mask_offsets is missing

swap_states(line 937)的逻辑相比,此处的 swap_data 块(init_share_inputs 路径)新增了 decode_statesattn_mask_offsets_full,但遗漏了 attn_mask_offsets。若两处 swap 逻辑不一致,可能导致 attn_mask_offsets 在某些 preemption 场景下数据错乱。建议补充:

swap_data(self.attn_mask_offsets, i1, i2)

@codecov-commenter

Codecov Report

❌ Patch coverage is 43.03797% with 45 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@9ef8467). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/worker/gpu_model_runner.py 50.00% 19 Missing and 6 partials ⚠️
fastdeploy/worker/input_batch.py 31.03% 20 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7576   +/-   ##
==============================================
  Coverage               ?   73.76%           
==============================================
  Files                  ?      376           
  Lines                  ?    53169           
  Branches               ?     8315           
==============================================
  Hits                   ?    39219           
  Misses                 ?    11192           
  Partials               ?     2758           
Flag Coverage Δ
GPU 73.76% <43.03%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

