-
Notifications
You must be signed in to change notification settings - Fork 742
[Optimization][DeepSeekV3.2]Precompute the attention_mask_offset for Prefill in the Indexer #7598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -673,20 +673,10 @@ def forward( | |
| k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous() | ||
| k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real | ||
|
|
||
| # TODO(changwenbin): Constructed using maskoffset | ||
| # ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous() | ||
| num_tokens = q_fp8.shape[0] | ||
| ks = paddle.zeros(num_tokens, dtype=paddle.int32) | ||
| ke = paddle.zeros(num_tokens, dtype=paddle.int32) | ||
|
|
||
| bsz = forward_meta.seq_lens_this_time.shape[0] | ||
| for i in range(bsz): | ||
| if forward_meta.seq_lens_encoder[i] > 0: | ||
| token_start_k = forward_meta.cu_seqlens_k[i] | ||
| token_end_k = forward_meta.cu_seqlens_k[i + 1] | ||
| ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i] | ||
| ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1 | ||
|
|
||
| # indexer_attn_mask_offsets is pre-computed by the Triton kernel | ||
| # update_indexer_attn_mask_offsets in dsa_attention_backend and stored in forward_meta. | ||
| ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous() | ||
This comment was marked as outdated.
Sorry, something went wrong. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ❓ 疑问:是否可以添加断言保证前置条件?assert forward_meta.indexer_attn_mask_offsets is not None, \
"indexer_attn_mask_offsets must be precomputed by DSAAttentionBackend"
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous() |
||
| ke = forward_meta.indexer_attn_mask_offsets[1::2].contiguous() | ||
| max_seqlen_k = (ke - ks).max().item() | ||
|
|
||
| logits = deep_gemm.fp8_mqa_logits( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| """ | ||
| # Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
|
|
||
| import paddle | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( | ||
| enable_compat_on_triton_kernel, | ||
| ) | ||
|
|
||
|
|
||
@enable_compat_on_triton_kernel
@triton.jit
def update_attn_mask_offsets_kernel(
    seq_lens_this_time_ptr,
    seq_lens_encoder_ptr,
    cu_seqlens_k_ptr,
    attn_mask_offsets_ptr,
    BLOCK_M: tl.constexpr,
):
    """Fill per-token causal attention-mask offsets for one prefill request.

    Layout (all 1-D tensors):
        seq_lens_this_time: [bsz]    tokens processed this step per request
        seq_lens_encoder:   [bsz]    prefill length per request (0 => decode)
        cu_seqlens_k:       [bsz+1]  cumulative key-token offsets
        attn_mask_offsets:  [num_tokens * 2]; even indices hold the window
                            start, odd indices the (exclusive) window end

    One program instance handles one request, i.e. grid = (bsz,).
    """
    req_id = tl.program_id(0)

    # Decode requests (seq_lens_encoder == 0) keep their zero-initialised
    # offsets — nothing to write for them.
    encoder_len = tl.load(seq_lens_encoder_ptr + req_id)
    if encoder_len <= 0:
        return

    step_len = tl.load(seq_lens_this_time_ptr + req_id)
    # Start offset of this request in the flattened token dimension.
    base = tl.load(cu_seqlens_k_ptr + req_id)

    for chunk_start in range(0, step_len, BLOCK_M):
        lane = chunk_start + tl.arange(0, BLOCK_M)
        valid = lane < step_len

        token_idx = base + lane  # global token positions, [BLOCK_M]

        # Causal window: every token of the request shares the same start
        # (the request's first key), while each token's end is its own
        # position + 1 (so the token attends to itself).
        win_start = tl.full((BLOCK_M,), base, dtype=tl.int32)
        win_end = (token_idx + 1).to(tl.int32)

        tl.store(attn_mask_offsets_ptr + token_idx * 2, win_start, mask=valid)
        tl.store(attn_mask_offsets_ptr + token_idx * 2 + 1, win_end, mask=valid)
|
|
||
|
|
||
def update_indexer_attn_mask_offsets(
    ids_remove_padding,
    seq_lens_this_time,
    seq_lens_encoder,
    cu_seqlens_k,
):
    """Precompute [start, end) causal mask offsets for every flattened token.

    Args:
        ids_remove_padding: [num_tokens] flattened token ids; only its length
            is used here, to size the output buffer.
        seq_lens_this_time: [bsz] tokens processed this step per request.
        seq_lens_encoder: [bsz] prefill length per request (0 => decode).
        cu_seqlens_k: [bsz + 1] cumulative key-token offsets.

    Returns:
        An int32 paddle tensor of shape [num_tokens * 2]; even entries hold
        each token's window start and odd entries its exclusive window end.
        Entries belonging to decode tokens remain 0.
    """
    assert ids_remove_padding.ndim == 1
    assert seq_lens_this_time.ndim == 1
    assert seq_lens_encoder.ndim == 1
    assert cu_seqlens_k.ndim == 1

    num_tokens = ids_remove_padding.shape[0]
    batch_size = seq_lens_this_time.shape[0]

    # Zero-initialised so decode tokens (skipped inside the kernel) read as 0.
    offsets = paddle.zeros((num_tokens * 2), dtype=paddle.int32)

    block_m = 128
    update_attn_mask_offsets_kernel[(batch_size,)](
        seq_lens_this_time,
        seq_lens_encoder,
        cu_seqlens_k,
        offsets,
        block_m,
    )
    return offsets
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议
_update_forward_meta 在每次 forward(含纯 Decode 阶段)时都会被调用,存在不必要的 GPU 内存分配开销。对于纯 Decode batch,seq_lens_encoder 全为 0,kernel 内所有 block 都会直接 return,但 paddle.zeros((num_tokens * 2), ...) 的分配仍然发生。在高吞吐 Decode 场景下,这是每步都触发的冗余 GPU allocation。建议加入 forward_mode 判断提前跳过。