[Optimization][DeepSeekV3.2] Precompute the attention_mask_offset for Prefill in the Indexer #7598

base: develop · Changes from all commits
`dsa_attention_backend.py`:

```diff
@@ -45,6 +45,7 @@
     AttentionMetadata,
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets


 def yarn_get_mscale(scale=1, mscale=1):
```
|
```diff
@@ -196,6 +197,16 @@ def quantize_k_cache(
     result = result.view(num_blocks, block_size, 1, -1)
     return result

+    def _update_forward_meta(self, forward_meta: ForwardMeta):
+        """Update forward meta data."""
+        # Indexer attn_mask_offset
+        forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(
+            forward_meta.ids_remove_padding,
+            forward_meta.seq_lens_this_time,
+            forward_meta.seq_lens_encoder,
+            forward_meta.cu_seqlens_k,
+        )
+
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = DSAAttentionMetadata()
```

Review comment on `_update_forward_meta`:

🟡 Suggestion: for a pure-decode batch, add an early return so the kernel is only launched during the prefill/mixed stages:

```python
from fastdeploy.model_executor.forward_meta import ForwardMode

def _update_forward_meta(self, forward_meta: ForwardMeta):
    # Only needed during Prefill / Mixed stages
    if forward_meta.forward_mode == ForwardMode.DECODE:
        return
    forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(...)
```
```diff
@@ -256,6 +267,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
             self.rank, int(self.device_id), self.keep_pd_step_flag
         )

+        self._update_forward_meta(forward_meta)
         self.attention_metadata: AttentionMetadata = metadata

     def get_attention_meta(self) -> AttentionMetadata:
```
Indexer `forward` (the consumer of the precomputed offsets):

```diff
@@ -673,20 +673,10 @@ def forward(
         k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
         k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real

-        # TODO(changwenbin): Constructed using maskoffset
-        # ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
-        num_tokens = q_fp8.shape[0]
-        ks = paddle.zeros(num_tokens, dtype=paddle.int32)
-        ke = paddle.zeros(num_tokens, dtype=paddle.int32)
-
-        bsz = forward_meta.seq_lens_this_time.shape[0]
-        for i in range(bsz):
-            if forward_meta.seq_lens_encoder[i] > 0:
-                token_start_k = forward_meta.cu_seqlens_k[i]
-                token_end_k = forward_meta.cu_seqlens_k[i + 1]
-                ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i]
-                ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1
+        # indexer_attn_mask_offsets is pre-computed by the Triton kernel
+        # update_indexer_attn_mask_offsets in dsa_attention_backend and stored in forward_meta.
+        ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()
+        ke = forward_meta.indexer_attn_mask_offsets[1::2].contiguous()
         max_seqlen_k = (ke - ks).max().item()

         logits = deep_gemm.fp8_mqa_logits(
```

Review comment on the new `ks` line:

❓ Question: if `forward_meta.indexer_attn_mask_offsets` was never populated, could an assertion be added to guarantee the precondition?

```python
assert forward_meta.indexer_attn_mask_offsets is not None, \
    "indexer_attn_mask_offsets must be precomputed by DSAAttentionBackend"
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()
```
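Since the new path must be bit-identical to the loop it replaces, a parity check makes the contract explicit. A minimal sketch, assuming the op is importable as in the backend diff above and a Triton-capable device; the batch values are made up (two prefill requests plus one decode request):

```python
import paddle

from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets

# Made-up batch: prefill lengths 5 and 3, one decode request (encoder len 0).
seq_lens_this_time = paddle.to_tensor([5, 3, 1], dtype=paddle.int32)
seq_lens_encoder = paddle.to_tensor([5, 3, 0], dtype=paddle.int32)
cu_seqlens_k = paddle.to_tensor([0, 5, 8, 9], dtype=paddle.int32)
num_tokens = int(cu_seqlens_k[-1])  # 9 flattened tokens
ids_remove_padding = paddle.zeros([num_tokens], dtype=paddle.int64)  # only its shape is read

# Reference: the per-token Python loop this PR removes.
ks_ref = paddle.zeros([num_tokens], dtype=paddle.int32)
ke_ref = paddle.zeros([num_tokens], dtype=paddle.int32)
for i in range(seq_lens_this_time.shape[0]):
    if seq_lens_encoder[i] > 0:
        start, end = int(cu_seqlens_k[i]), int(cu_seqlens_k[i + 1])
        ks_ref[start:end] = start
        ke_ref[start:end] = paddle.arange(start, end, dtype=paddle.int32) + 1

# New path: deinterleave the precomputed buffer exactly as the Indexer does.
offsets = update_indexer_attn_mask_offsets(
    ids_remove_padding, seq_lens_this_time, seq_lens_encoder, cu_seqlens_k
)
assert paddle.equal_all(offsets[::2], ks_ref) and paddle.equal_all(offsets[1::2], ke_ref)
```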
New file — the Triton op `update_indexer_attn_mask_offsets` under `fastdeploy/model_executor/ops/triton_ops` (+90 lines):

```python
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
    enable_compat_on_triton_kernel,
)


@enable_compat_on_triton_kernel
@triton.jit
def update_attn_mask_offsets_kernel(
    seq_lens_this_time_ptr,
    seq_lens_encoder_ptr,
    cu_seqlens_k_ptr,
    attn_mask_offsets_ptr,
    BLOCK_M: tl.constexpr,
):
    """
    seq_lens_this_time: [bsz]
    seq_lens_encoder: [bsz]
    cu_seqlens_k: [bsz+1]
    attn_mask_offsets: [num_tokens * 2], even indices = start, odd indices = end
    """
    batch_id = tl.program_id(0)

    seq_len_encoder = tl.load(seq_lens_encoder_ptr + batch_id)

    # Skip decode requests (seq_lens_encoder == 0); offsets remain 0
    if seq_len_encoder <= 0:
        return

    seq_len_this_time = tl.load(seq_lens_this_time_ptr + batch_id)
    token_start_k = tl.load(cu_seqlens_k_ptr + batch_id)  # start offset in flattened token dim

    for block_start in range(0, seq_len_this_time, BLOCK_M):
        offsets = block_start + tl.arange(0, BLOCK_M)
        mask = offsets < seq_len_this_time

        global_token_idx = token_start_k + offsets  # [BLOCK_M]

        # causal window: start = batch k start (same for all tokens), end = current token + 1
        ks = tl.full((BLOCK_M,), token_start_k, dtype=tl.int32)
        ke = (global_token_idx + 1).to(tl.int32)

        tl.store(attn_mask_offsets_ptr + global_token_idx * 2, ks, mask=mask)
        tl.store(attn_mask_offsets_ptr + global_token_idx * 2 + 1, ke, mask=mask)


def update_indexer_attn_mask_offsets(
    ids_remove_padding,
    seq_lens_this_time,
    seq_lens_encoder,
    cu_seqlens_k,
):
    assert ids_remove_padding.ndim == 1
    assert seq_lens_this_time.ndim == 1
    assert seq_lens_encoder.ndim == 1
    assert cu_seqlens_k.ndim == 1
    num_tokens = ids_remove_padding.shape[0]
    bsz = seq_lens_this_time.shape[0]
    attention_mask_offset = paddle.zeros((num_tokens * 2), dtype=paddle.int32)

    BLOCK_M = 128
    grid = (bsz,)

    update_attn_mask_offsets_kernel[grid](
        seq_lens_this_time,
        seq_lens_encoder,
        cu_seqlens_k,
        attention_mask_offset,
        BLOCK_M,
    )
    return attention_mask_offset
```

Review comments:

Collaborator (on the decorator stack): 1. Clean up the unnecessary comments.
Contributor (author): ok

🟡 Suggestion (on the `ids_remove_padding` parameter): inside the function it is only used as the source of `num_tokens`. Derive that from `cu_seqlens_k` instead and remove the parameter from the signature:

```python
num_tokens = int(cu_seqlens_k[-1].item())
```
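For intuition about the interleaved layout, the values below are worked out by hand from the kernel's rule (ks = the request's k start, ke = global token index + 1, decode slots stay zero); the batch itself is hypothetical:

```python
# Hypothetical batch: two prefill requests of 3 and 2 tokens, plus one decode token.
#   seq_lens_this_time = [3, 2, 1]
#   seq_lens_encoder   = [3, 2, 0]
#   cu_seqlens_k       = [0, 3, 5, 6]   -> num_tokens = 6
#
# token t :  0  1  2  3  4  5
# ks[t]   :  0  0  0  3  3  0   # start of each request's k range; decode slot stays 0
# ke[t]   :  1  2  3  4  5  0   # global token index + 1 (causal end); decode slot stays 0
#
# Interleaved storage (even index = ks, odd index = ke):
expected_offsets = [0, 1, 0, 2, 0, 3, 3, 4, 3, 5, 0, 0]
```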