[Optimization][DeepSeekV3.2] Precompute the attention_mask_offset for Prefill in the Indexer #7598
base: develop · Changes from 5 commits
File 1 of 3 — DSA attention backend (`dsa_attention_backend`)
```diff
@@ -45,6 +45,7 @@
     AttentionMetadata,
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets


 def yarn_get_mscale(scale=1, mscale=1):
```
```diff
@@ -196,6 +197,16 @@ def quantize_k_cache(
     result = result.view(num_blocks, block_size, 1, -1)
     return result

+def _update_forward_meta(self, forward_meta: ForwardMeta):
```
🟡 Suggestion: for a pure decode batch, consider an early return:

```python
from fastdeploy.model_executor.forward_meta import ForwardMode

def _update_forward_meta(self, forward_meta: ForwardMeta):
    # Only needed during the Prefill / Mixed stages
    if forward_meta.forward_mode == ForwardMode.DECODE:
        return
    forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(...)
```
```diff
+    """Update forward meta data."""
+    # Indexer attn_mask_offset
+    forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(
+        forward_meta.ids_remove_padding,
+        forward_meta.seq_lens_this_time,
+        forward_meta.seq_lens_encoder,
+        forward_meta.cu_seqlens_k,
+    )
+
 def init_attention_metadata(self, forward_meta: ForwardMeta):
     """Initialize attention metadata hence all layers in the forward pass can reuse it."""
     metadata = DSAAttentionMetadata()
@@ -256,6 +267,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         self.rank, int(self.device_id), self.keep_pd_step_flag
     )

+    self._update_forward_meta(forward_meta)
     self.attention_metadata: AttentionMetadata = metadata

 def get_attention_meta(self) -> AttentionMetadata:
```
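To make the interleaved layout concrete, here is a minimal pure-Python trace (hypothetical batch; no Paddle or Triton required) of what the precomputed tensor holds for a mixed prefill/decode batch — even slots store each token's causal-window start, odd slots its end:

```python
# Hypothetical mixed batch: request 0 is a prefill of 3 tokens,
# request 1 is a single decode token (seq_lens_encoder == 0).
seq_lens_this_time = [3, 1]
seq_lens_encoder = [3, 0]
cu_seqlens_k = [0, 3, 4]

num_tokens = cu_seqlens_k[-1]
offsets = [0] * (num_tokens * 2)
for i, enc_len in enumerate(seq_lens_encoder):
    if enc_len <= 0:
        continue  # decode tokens keep their zero entries
    start, end = cu_seqlens_k[i], cu_seqlens_k[i + 1]
    for t in range(start, end):
        offsets[2 * t] = start       # even slot: causal window start (this batch's k start)
        offsets[2 * t + 1] = t + 1   # odd slot: causal window end (token attends up to itself)

assert offsets == [0, 1, 0, 2, 0, 3, 0, 0]
```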
File 2 of 3 — DeepSeek V3.2 model definition
```diff
@@ -70,6 +70,8 @@
     radix_topk_ragged_transform,
 )

+paddle.enable_compat(scope={"deep_gemm": True})
+

 class DeepSeekV3MLP(nn.Layer):
     """
```
```diff
@@ -659,7 +661,8 @@ def forward(
             k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt
         )

-        from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
+        # from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
+        import deep_gemm

         if forward_meta.max_len_tensor_cpu[1]:
```
```diff
@@ -673,20 +676,10 @@ def forward(
             k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
             k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real

-        # TODO(changwenbin): Constructed using maskoffset
-        # ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
-        num_tokens = q_fp8.shape[0]
-        ks = paddle.zeros(num_tokens, dtype=paddle.int32)
-        ke = paddle.zeros(num_tokens, dtype=paddle.int32)
-
-        bsz = forward_meta.seq_lens_this_time.shape[0]
-        for i in range(bsz):
-            if forward_meta.seq_lens_encoder[i] > 0:
-                token_start_k = forward_meta.cu_seqlens_k[i]
-                token_end_k = forward_meta.cu_seqlens_k[i + 1]
-                ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i]
-                ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1
-
+        # indexer_attn_mask_offsets is pre-computed by the Triton kernel
+        # update_indexer_attn_mask_offsets in dsa_attention_backend and stored in forward_meta.
+        ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()
```
❓ Question: could an assertion be added here to guarantee the precondition?

```python
assert forward_meta.indexer_attn_mask_offsets is not None, \
    "indexer_attn_mask_offsets must be precomputed by DSAAttentionBackend"
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()
```
```diff
+        ke = forward_meta.indexer_attn_mask_offsets[1::2].contiguous()
         max_seqlen_k = (ke - ks).max().item()

         logits = deep_gemm.fp8_mqa_logits(
```
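Downstream, the interleaved tensor is split back into per-token window starts and ends by stride-2 slicing. A minimal sketch, reusing the hypothetical batch from the earlier trace (Paddle only; values are illustrative):

```python
import paddle

# Same hypothetical batch as above; interleaved [start, end] pairs per token.
offsets = paddle.to_tensor([0, 1, 0, 2, 0, 3, 0, 0], dtype=paddle.int32)

ks = offsets[::2].contiguous()    # window starts: [0, 0, 0, 0]
ke = offsets[1::2].contiguous()   # window ends:   [1, 2, 3, 0]
max_seqlen_k = (ke - ks).max().item()  # -> 3
```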
File 3 of 3 — new file (@@ -0,0 +1,99 @@): the Triton op `update_indexer_attn_mask_offsets`
| """ | ||
| # Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ | ||
|
|
||
| import paddle | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( | ||
| enable_compat_on_triton_kernel, | ||
| ) | ||
|
|
||
|
|
||
| @enable_compat_on_triton_kernel | ||
|
Collaborator: 1. Clean up the unnecessary comments.

Contributor (author): ok
```python
@triton.jit
def update_attn_mask_offsets_kernel(
    seq_lens_this_time_ptr,
    seq_lens_encoder_ptr,
    cu_seqlens_k_ptr,
    attn_mask_offsets_ptr,
    BLOCK_M: tl.constexpr,
):
    """
    seq_lens_this_time: [bsz]
    seq_lens_encoder: [bsz]
    cu_seqlens_k: [bsz+1]
    attn_mask_offsets: [num_tokens * 2]
        - even slots = start
        - odd slots = end
    """
    batch_id = tl.program_id(0)

    seq_len_encoder = tl.load(seq_lens_encoder_ptr + batch_id)

    # Tokens of decode requests (seq_lens_encoder == 0) do not take the prefill
    # path in the Indexer; their attn_mask_offsets entries stay 0, so just skip.
    if seq_len_encoder <= 0:
        return

    seq_len_this_time = tl.load(seq_lens_this_time_ptr + batch_id)
    token_start_k = tl.load(cu_seqlens_k_ptr + batch_id)  # start offset of this batch in the flattened token dimension

    # One program per batch; walk its tokens in chunks of BLOCK_M
    for block_start in range(0, seq_len_this_time, BLOCK_M):
        offsets = block_start + tl.arange(0, BLOCK_M)  # token offsets local to this batch
        mask = offsets < seq_len_this_time

        # Global token indices in the flattened token dimension
        global_token_idx = token_start_k + offsets  # [BLOCK_M]

        # start: left edge of the causal window = start of this batch's k sequence (same for every token)
        ks = tl.full((BLOCK_M,), token_start_k, dtype=tl.int32)

        # end: right edge of the causal window = global index of the current token + 1
        # (a token may only attend to itself and earlier tokens)
        ke = (global_token_idx + 1).to(tl.int32)

        # Write to attn_mask_offsets: even slots hold start, odd slots hold end
        tl.store(attn_mask_offsets_ptr + global_token_idx * 2, ks, mask=mask)
        tl.store(attn_mask_offsets_ptr + global_token_idx * 2 + 1, ke, mask=mask)


def update_indexer_attn_mask_offsets(
    ids_remove_padding,
```
🟡 Suggestion: this argument is only used inside the function for its shape. Consider

```python
num_tokens = int(cu_seqlens_k[-1].item())
```

and removing it from the function signature.
```python
    seq_lens_this_time,
    seq_lens_encoder,
    cu_seqlens_k,
):
    assert ids_remove_padding.ndim == 1
    assert seq_lens_this_time.ndim == 1
    assert seq_lens_encoder.ndim == 1
    assert cu_seqlens_k.ndim == 1
    num_tokens = ids_remove_padding.shape[0]
    bsz = seq_lens_this_time.shape[0]
    attention_mask_offset = paddle.zeros((num_tokens * 2), dtype=paddle.int32)

    # One Triton program per batch; BLOCK_SIZE is the token-chunk size each program processes
    BLOCK_SIZE = 128
    grid = (bsz,)

    update_attn_mask_offsets_kernel[grid](
        seq_lens_this_time,
        seq_lens_encoder,
        cu_seqlens_k,
        attention_mask_offset,
        BLOCK_M=BLOCK_SIZE,
    )
    return attention_mask_offset
```
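One way to sanity-check the kernel is a parity test against the Python loop it replaces. This is a sketch only, assuming a CUDA device with Triton available and hypothetical inputs; `reference_offsets` is a helper introduced here for illustration, not part of the PR:

```python
import paddle

from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets

def reference_offsets(seq_lens_this_time, seq_lens_encoder, cu_seqlens_k, num_tokens):
    """Loop-based reference mirroring the removed Python implementation."""
    out = paddle.zeros([num_tokens * 2], dtype=paddle.int32)
    for i in range(seq_lens_this_time.shape[0]):
        if int(seq_lens_encoder[i]) > 0:
            s, e = int(cu_seqlens_k[i]), int(cu_seqlens_k[i + 1])
            for t in range(s, e):
                out[2 * t] = s          # even slot: causal window start
                out[2 * t + 1] = t + 1  # odd slot: causal window end
    return out

# Hypothetical mixed batch: one prefill of 3 tokens plus one decode token.
seq_lens_this_time = paddle.to_tensor([3, 1], dtype=paddle.int32)
seq_lens_encoder = paddle.to_tensor([3, 0], dtype=paddle.int32)
cu_seqlens_k = paddle.to_tensor([0, 3, 4], dtype=paddle.int32)
ids_remove_padding = paddle.zeros([4], dtype=paddle.int64)  # only its length is used

got = update_indexer_attn_mask_offsets(
    ids_remove_padding, seq_lens_this_time, seq_lens_encoder, cu_seqlens_k
)
ref = reference_offsets(seq_lens_this_time, seq_lens_encoder, cu_seqlens_k, 4)
assert bool((got == ref).all())
```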