1 change: 1 addition & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -164,6 +164,7 @@ class ForwardMeta

# for mla & dsa
position_ids: Optional[paddle.Tensor] = None
indexer_attn_mask_offsets: Optional[paddle.Tensor] = None
# for kvcache slot
slot_mapping: Optional[paddle.Tensor] = None

@@ -45,6 +45,7 @@
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets


def yarn_get_mscale(scale=1, mscale=1):
@@ -196,6 +197,16 @@ def quantize_k_cache(
result = result.view(num_blocks, block_size, 1, -1)
return result

def _update_forward_meta(self, forward_meta: ForwardMeta):
🟡 Suggestion: _update_forward_meta is called on every forward pass (including pure-decode steps), which introduces unnecessary GPU memory allocation overhead.

For a pure-decode batch, seq_lens_encoder is all zeros, so every block in the kernel returns immediately, yet the paddle.zeros((num_tokens * 2), ...) allocation still takes place. In high-throughput decode scenarios this is a redundant GPU allocation triggered on every step.

Consider adding a forward_mode check to skip it early:

from fastdeploy.model_executor.forward_meta import ForwardMode

def _update_forward_meta(self, forward_meta: ForwardMeta):
    # Only needed during Prefill / Mixed stages
    if forward_meta.forward_mode == ForwardMode.DECODE:
        return
    forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(...)

"""Update forward meta data."""
# Indexer attn_mask_offset
forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(
forward_meta.ids_remove_padding,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.cu_seqlens_k,
)

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = DSAAttentionMetadata()
@@ -256,6 +267,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.rank, int(self.device_id), self.keep_pd_step_flag
)

self._update_forward_meta(forward_meta)
self.attention_metadata: AttentionMetadata = metadata

def get_attention_meta(self) -> AttentionMetadata:
18 changes: 4 additions & 14 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -673,20 +673,10 @@ def forward(
k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real

# TODO(changwenbin): Constructed using maskoffset
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
num_tokens = q_fp8.shape[0]
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
ke = paddle.zeros(num_tokens, dtype=paddle.int32)

bsz = forward_meta.seq_lens_this_time.shape[0]
for i in range(bsz):
if forward_meta.seq_lens_encoder[i] > 0:
token_start_k = forward_meta.cu_seqlens_k[i]
token_end_k = forward_meta.cu_seqlens_k[i + 1]
ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i]
ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1

# indexer_attn_mask_offsets is pre-computed by the Triton kernel
# update_indexer_attn_mask_offsets in dsa_attention_backend and stored in forward_meta.
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()

❓ Question: forward_meta.indexer_attn_mask_offsets is typed Optional[paddle.Tensor], but it is sliced here without a None check.

If init_attention_metadata is never called (for example, a unit test constructs ForwardMeta directly, or Triton is unavailable so _update_forward_meta raises an ImportError and is skipped), the field stays None and the [::2] slice raises a TypeError.

Could an assertion be added to guarantee this precondition?

assert forward_meta.indexer_attn_mask_offsets is not None, \
    "indexer_attn_mask_offsets must be precomputed by DSAAttentionBackend"
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()

ke = forward_meta.indexer_attn_mask_offsets[1::2].contiguous()
max_seqlen_k = (ke - ks).max().item()

logits = deep_gemm.fp8_mqa_logits(
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/ops/triton_ops/__init__.py
@@ -15,6 +15,7 @@
"""

try:
from .indexer_update_attn_mask_offsets import update_indexer_attn_mask_offsets
from .pre_token_quant_fp8_kernel import _per_token_group_quant_fp8
from .qk_rmsnorm_fused_kernel import qk_rmsnorm_fused
from .repetition_early_stop_kernel import repetition_early_stopper_kernel
@@ -27,6 +28,7 @@
"repetition_early_stopper_kernel",
"qk_rmsnorm_fused",
"_per_token_group_quant_fp8",
"update_indexer_attn_mask_offsets",
]
except:
_TRITON_AVAILABLE = False
90 changes: 90 additions & 0 deletions fastdeploy/model_executor/ops/triton_ops/indexer_update_attn_mask_offsets.py
@@ -0,0 +1,90 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
enable_compat_on_triton_kernel,
)


@enable_compat_on_triton_kernel
Collaborator

1. Clean up the unnecessary comments.
2. Change them to English.

Contributor Author

ok

@triton.jit
def update_attn_mask_offsets_kernel(
seq_lens_this_time_ptr,
seq_lens_encoder_ptr,
cu_seqlens_k_ptr,
attn_mask_offsets_ptr,
BLOCK_M: tl.constexpr,
):
"""
seq_lens_this_time: [bsz]
seq_lens_encoder: [bsz]
cu_seqlens_k: [bsz+1]
attn_mask_offsets: [num_tokens * 2], even indices = start, odd indices = end
"""
batch_id = tl.program_id(0)

seq_len_encoder = tl.load(seq_lens_encoder_ptr + batch_id)

# Skip decode requests (seq_lens_encoder == 0); offsets remain 0
if seq_len_encoder <= 0:
return

seq_len_this_time = tl.load(seq_lens_this_time_ptr + batch_id)
token_start_k = tl.load(cu_seqlens_k_ptr + batch_id) # start offset in flattened token dim

for block_start in range(0, seq_len_this_time, BLOCK_M):
offsets = block_start + tl.arange(0, BLOCK_M)
mask = offsets < seq_len_this_time

global_token_idx = token_start_k + offsets # [BLOCK_M]

# causal window: start = batch k start (same for all tokens), end = current token + 1
ks = tl.full((BLOCK_M,), token_start_k, dtype=tl.int32)
ke = (global_token_idx + 1).to(tl.int32)

tl.store(attn_mask_offsets_ptr + global_token_idx * 2, ks, mask=mask)
tl.store(attn_mask_offsets_ptr + global_token_idx * 2 + 1, ke, mask=mask)


def update_indexer_attn_mask_offsets(
ids_remove_padding,

🟡 Suggestion: the ids_remove_padding parameter is only used to obtain num_tokens; its actual data is never passed to the kernel.

Inside the function it serves only as num_tokens = ids_remove_padding.shape[0], while cu_seqlens_k[-1] equals sum(seq_lens_this_time) and provides the same value, so there is no need to pass the whole tensor. The current interface misleads callers into thinking the kernel reads the token IDs.

Suggested change:

num_tokens = int(cu_seqlens_k[-1].item())

Then remove ids_remove_padding from the function signature and update the call site in dsa_attention_backend.py accordingly.
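A minimal sketch of the refactored wrapper under that assumption (no other call sites depend on the parameter; everything else is unchanged):

def update_indexer_attn_mask_offsets(
    seq_lens_this_time,
    seq_lens_encoder,
    cu_seqlens_k,
):
    # cu_seqlens_k[-1] == sum(seq_lens_this_time), so num_tokens needs no extra tensor
    num_tokens = int(cu_seqlens_k[-1].item())
    bsz = seq_lens_this_time.shape[0]
    attention_mask_offset = paddle.zeros((num_tokens * 2,), dtype=paddle.int32)
    update_attn_mask_offsets_kernel[(bsz,)](
        seq_lens_this_time,
        seq_lens_encoder,
        cu_seqlens_k,
        attention_mask_offset,
        128,  # BLOCK_M
    )
    return attention_mask_offset

Note that .item() forces a device-to-host sync, which the current shape[0] approach avoids; whether that trade-off matters depends on how often this runs per step.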

seq_lens_this_time,
seq_lens_encoder,
cu_seqlens_k,
):
assert ids_remove_padding.ndim == 1
assert seq_lens_this_time.ndim == 1
assert seq_lens_encoder.ndim == 1
assert cu_seqlens_k.ndim == 1
num_tokens = ids_remove_padding.shape[0]
bsz = seq_lens_this_time.shape[0]
attention_mask_offset = paddle.zeros((num_tokens * 2), dtype=paddle.int32)

BLOCK_M = 128
grid = (bsz,)
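# One program instance per request; each program walks its own tokens in BLOCK_M-sized chunks.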

update_attn_mask_offsets_kernel[grid](
seq_lens_this_time,
seq_lens_encoder,
cu_seqlens_k,
attention_mask_offset,
BLOCK_M,
)
return attention_mask_offset
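For reference, a minimal usage sketch with toy values (hypothetical shapes; assumes Triton and a GPU device are available):

import paddle

from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets

seq_lens_this_time = paddle.to_tensor([3, 1], dtype=paddle.int32)  # req 0: prefill (3 tokens), req 1: decode
seq_lens_encoder = paddle.to_tensor([3, 0], dtype=paddle.int32)
cu_seqlens_k = paddle.to_tensor([0, 3, 4], dtype=paddle.int32)
ids_remove_padding = paddle.zeros([4], dtype=paddle.int64)  # only shape[0] (num_tokens) is read
offsets = update_indexer_attn_mask_offsets(
    ids_remove_padding, seq_lens_this_time, seq_lens_encoder, cu_seqlens_k
)
# offsets -> [0, 1, 0, 2, 0, 3, 0, 0]: the prefill tokens get causal (ks, ke) pairs,
# while the decode token's pair stays zero (the kernel's early return skips it)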
1 change: 0 additions & 1 deletion fastdeploy/worker/gpu_model_runner.py
@@ -1737,7 +1737,6 @@ def _initialize_attn_backend(self) -> None:
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
)

self.attn_backends.append(attn_backend)

def _dummy_pooler_run_task(