
[Optimization][DeepSeekV3.2] Precompute the attention_mask_offset for Prefill in the Indexer #7598

Open
ShaneGZhu wants to merge 6 commits into PaddlePaddle:develop from ShaneGZhu:indexer_dev

Conversation

ShaneGZhu (Contributor) commented Apr 23, 2026

Motivation

In DeepSeekV3.2's Indexer Prefill stage, the attention_mask_offset (i.e., the causal attention window [ks, ke) for each token) was previously computed via Python loops on CPU — iterating batch by batch and token by token. This becomes a bottleneck on the critical forward path when handling large batch sizes or long sequences.

This PR precomputes attention_mask_offset using a Triton kernel on GPU during init_attention_metadata, eliminating the CPU-side Python loop overhead and improving Prefill throughput for the Indexer attention backend.


Modifications

  • fastdeploy/model_executor/ops/triton_ops/indexer_update_attn_mask_offsets.py (new file)
    Implements the Triton kernel update_indexer_attn_mask_offsets that batch-computes the causal attention window [ks, ke) for all prefill tokens in a single GPU kernel launch. Also provides a Python reference implementation ref_update_attn_mask_offsets for correctness verification.
  • fastdeploy/model_executor/layers/attention/dsa_attention_backend.py
    Adds a _update_forward_meta method that calls update_indexer_attn_mask_offsets to precompute attention_mask_offset and stores the result into forward_meta.indexer_attn_mask_offsets.
  • fastdeploy/model_executor/forward_meta.py
    Adds indexer_attn_mask_offsets field to ForwardMeta to carry the precomputed offsets.
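To illustrate the computation being moved to the GPU, here is a minimal pure-Python sketch of the per-token window logic the kernel replaces. The function name and the exact window semantics (ks = 0, ke = cached length + local position + 1, interleaved into a flat buffer) are assumptions based on the PR description; the actual reference lives in ref_update_attn_mask_offsets.

```python
def compute_attn_mask_offsets(seq_lens_this_time, seq_lens_decoder):
    """Hypothetical sketch of the causal window [ks, ke) per prefill token.

    Assumed semantics: each token may attend to all previously cached
    tokens plus the earlier tokens of the current chunk and itself, so
    ks = 0 and ke = cached_len + local_position + 1. The real kernel may
    define the window differently (e.g. relative to a paged-KV layout).
    """
    offsets = []  # interleaved [ks0, ke0, ks1, ke1, ...], one pair per token
    for seq_len, cached_len in zip(seq_lens_this_time, seq_lens_decoder):
        for pos in range(seq_len):
            ks = 0
            ke = cached_len + pos + 1
            offsets.extend([ks, ke])
    return offsets

# Two sequences: a fresh prefill of 3 tokens, and a chunked prefill of
# 2 tokens with 4 tokens already cached.
print(compute_attn_mask_offsets([3, 2], [0, 4]))
# → [0, 1, 0, 2, 0, 3, 0, 5, 0, 6]
```

The old code path ran this doubly nested loop on the CPU per forward pass; the Triton kernel computes all pairs in one launch during init_attention_metadata.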

Usage or Command

No API changes. The optimization is transparent to users. To verify correctness, run:

python -m pytest tests/model_executor/ops/triton_ops/test_indexer_update_attn_mask_offsets.py -v

Accuracy Tests

The new Triton kernel is validated against the Python reference implementation ref_update_attn_mask_offsets in the unit tests, covering:

  • Edge cases (single token, empty sequences)
  • Mixed-batch scenarios with varying sequence lengths

No changes to model forward logic or kernel math — accuracy of model outputs is unaffected.

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

paddle-bot (Bot) commented Apr 23, 2026

Thanks for your contribution!


codecov-commenter commented Apr 23, 2026

Codecov Report

❌ Patch coverage is 56.09756% with 18 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@a6a740f). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...ops/triton_ops/indexer_update_attn_mask_offsets.py 57.57% 14 Missing ⚠️
...executor/layers/attention/dsa_attention_backend.py 50.00% 2 Missing ⚠️
fastdeploy/model_executor/models/deepseek_v3.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7598   +/-   ##
==========================================
  Coverage           ?   71.68%           
==========================================
  Files              ?      420           
  Lines              ?    57885           
  Branches           ?     9077           
==========================================
  Hits               ?    41495           
  Misses             ?    13561           
  Partials           ?     2829           
Flag Coverage Δ
GPU 71.68% <56.09%> (?)


☔ View full report in Codecov by Sentry.

@ShaneGZhu ShaneGZhu changed the title [Optimize][DeepSeekV3.2]Precompute the attention_mask_offset for Prefill in the Indexer [Optimization][DeepSeekV3.2]Precompute the attention_mask_offset for Prefill in the Indexer Apr 24, 2026

@gongshaotian gongshaotian marked this pull request as ready for review April 24, 2026 11:34
)


@enable_compat_on_triton_kernel
Collaborator

1. Clean up the unnecessary comments.
2. Change them to English.

Contributor (Author)

ok

PaddlePaddle-bot left a comment

🤖 AI Code Review | 2026-04-24 22:33:33

📋 Review Summary

PR overview: Moves the attention_mask_offset computation in the DeepSeekV3.2 Indexer Prefill stage from a CPU-side Python loop to a Triton GPU kernel, eliminating Python-loop overhead on the hot path.
Scope of changes: model_executor/ops/triton_ops, model_executor/layers/attention, model_executor/models/deepseek_v3.py, forward_meta.py
Impact tags: Models, OP

Issues

Severity | File | Summary
🟡 Suggestion | indexer_update_attn_mask_offsets.py:67 | The ids_remove_padding parameter is redundant (only .shape[0] is used), making the interface semantically misleading
🟡 Suggestion | dsa_attention_backend.py:200 | _update_forward_meta triggers a GPU allocation on every forward pass, an unnecessary cost in pure-Decode steps
❓ Question | deepseek_v3.py:678 | An Optional field is sliced directly without a None guard

Overall Assessment

The optimization approach is clear, the Triton kernel implementation is correct, and the unit-test coverage is thorough (including edge cases and stress tests). Main suggestions: remove the redundant interface parameter, add a forward_mode guard to avoid the pointless allocation in the Decode stage, and add a None assertion on the consumer side in deepseek_v3.py.



def update_indexer_attn_mask_offsets(
ids_remove_padding,
🟡 Suggestion: The ids_remove_padding parameter is only used to obtain num_tokens; its actual data is never passed to the kernel.

Inside the function the parameter is only used as num_tokens = ids_remove_padding.shape[0]. Since cu_seqlens_k[-1] likewise equals sum(seq_lens_this_time), it can be used instead, so there is no need to pass in the whole tensor. The current interface can mislead callers into thinking the kernel reads the token-ID data.

Suggested change:

num_tokens = int(cu_seqlens_k[-1].item())

and remove the ids_remove_padding parameter from the function signature, updating the call site in dsa_attention_backend.py accordingly.

result = result.view(num_blocks, block_size, 1, -1)
return result

def _update_forward_meta(self, forward_meta: ForwardMeta):
🟡 Suggestion: _update_forward_meta is called on every forward pass (including pure-Decode steps), incurring unnecessary GPU memory allocation.

For a pure-Decode batch, seq_lens_encoder is all zeros and every block in the kernel returns immediately, yet the paddle.zeros((num_tokens * 2), ...) allocation still happens. In high-throughput Decode scenarios this is a redundant GPU allocation on every step.

Suggested fix: skip early based on forward_mode:

from fastdeploy.model_executor.forward_meta import ForwardMode

def _update_forward_meta(self, forward_meta: ForwardMeta):
    # Only needed during Prefill / Mixed stages
    if forward_meta.forward_mode == ForwardMode.DECODE:
        return
    forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(...)


# indexer_attn_mask_offsets is pre-computed by the Triton kernel
# update_indexer_attn_mask_offsets in dsa_attention_backend and stored in forward_meta.
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()
❓ Question: forward_meta.indexer_attn_mask_offsets is typed Optional[paddle.Tensor], but it is sliced here directly without a None check.

If init_attention_metadata is never called (for example, a unit test constructs ForwardMeta directly, or Triton is unavailable and _update_forward_meta is skipped due to an ImportError), the field stays None and the [::2] slice raises a TypeError.

Could an assertion be added to guarantee this precondition?

assert forward_meta.indexer_attn_mask_offsets is not None, \
    "indexer_attn_mask_offsets must be precomputed by DSAAttentionBackend"
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()
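For context, the interleaved layout means the consumer recovers the two boundary vectors by stride-2 slicing. A minimal list-based sketch of that unpacking, with a guard mirroring the suggested assertion (the helper name is hypothetical):

```python
def split_offsets(offsets):
    """Recover ks/ke from the flat interleaved buffer [ks0, ke0, ks1, ke1, ...].

    Mirrors the [::2] / [1::2] access pattern used on
    forward_meta.indexer_attn_mask_offsets, here on plain lists.
    """
    assert offsets is not None, "offsets must be precomputed before use"
    ks = offsets[::2]   # even positions: window starts
    ke = offsets[1::2]  # odd positions: window ends (exclusive)
    return ks, ke

print(split_offsets([0, 1, 0, 2, 0, 3]))
# → ([0, 0, 0], [1, 2, 3])
```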

