1 change: 1 addition & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -164,6 +164,7 @@ class ForwardMeta:

# for mla & dsa
position_ids: Optional[paddle.Tensor] = None
indexer_attn_mask_offsets: Optional[paddle.Tensor] = None
# for kvcache slot
slot_mapping: Optional[paddle.Tensor] = None
Comment on lines 165 to 169
Collaborator:

1. Does indexer_attn_mask_offsets have to live in ForwardMeta? What would be the problem with putting it in the attention metadata instead?
2. These comments do not follow the convention; please explain their semantics.


@@ -45,6 +45,7 @@
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.ops.triton_ops import update_indexer_attn_mask_offsets


def yarn_get_mscale(scale=1, mscale=1):
@@ -196,6 +197,16 @@ def quantize_k_cache(
result = result.view(num_blocks, block_size, 1, -1)
return result

def _update_forward_meta(self, forward_meta: ForwardMeta):

🟡 Suggestion: _update_forward_meta is called on every forward pass (including pure decode steps), which incurs unnecessary GPU memory allocation overhead.

For a pure decode batch, seq_lens_encoder is all zeros and every block in the kernel returns immediately, yet the paddle.zeros((num_tokens * 2), ...) allocation still happens. In high-throughput decode scenarios this is a redundant GPU allocation triggered on every step.

Consider adding a forward_mode check to skip it early:

from fastdeploy.model_executor.forward_meta import ForwardMode

def _update_forward_meta(self, forward_meta: ForwardMeta):
    # Only needed during Prefill / Mixed stages
    if forward_meta.forward_mode == ForwardMode.DECODE:
        return
    forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(...)

"""Update forward meta data."""
# Indexer attn_mask_offset
forward_meta.indexer_attn_mask_offsets = update_indexer_attn_mask_offsets(
forward_meta.ids_remove_padding,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.cu_seqlens_k,
)

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = DSAAttentionMetadata()
@@ -256,6 +267,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
self.rank, int(self.device_id), self.keep_pd_step_flag
)

self._update_forward_meta(forward_meta)
self.attention_metadata: AttentionMetadata = metadata

def get_attention_meta(self) -> AttentionMetadata:
23 changes: 8 additions & 15 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -70,6 +70,8 @@
radix_topk_ragged_transform,
)

paddle.enable_compat(scope={"deep_gemm": True})


class DeepSeekV3MLP(nn.Layer):
"""
@@ -659,7 +661,8 @@ def forward(
k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt
)

from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
# from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm

import deep_gemm

if forward_meta.max_len_tensor_cpu[1]:

@@ -673,20 +676,10 @@
k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real

# TODO(changwenbin): Constructed using maskoffset
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
num_tokens = q_fp8.shape[0]
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
ke = paddle.zeros(num_tokens, dtype=paddle.int32)

bsz = forward_meta.seq_lens_this_time.shape[0]
for i in range(bsz):
if forward_meta.seq_lens_encoder[i] > 0:
token_start_k = forward_meta.cu_seqlens_k[i]
token_end_k = forward_meta.cu_seqlens_k[i + 1]
ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i]
ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1

# indexer_attn_mask_offsets is pre-computed by the Triton kernel
# update_indexer_attn_mask_offsets in dsa_attention_backend and stored in forward_meta.
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()


❓ Question: forward_meta.indexer_attn_mask_offsets is typed Optional[paddle.Tensor], but it is sliced here without a None check.

If init_attention_metadata is never called (for example, a unit test constructs ForwardMeta directly, or Triton is unavailable so _update_forward_meta is skipped with an ImportError), the field stays None and slicing it with [::2] raises a TypeError.

Could an assertion be added to guarantee this precondition?

assert forward_meta.indexer_attn_mask_offsets is not None, \
    "indexer_attn_mask_offsets must be precomputed by DSAAttentionBackend"
ks = forward_meta.indexer_attn_mask_offsets[::2].contiguous()

ke = forward_meta.indexer_attn_mask_offsets[1::2].contiguous()
max_seqlen_k = (ke - ks).max().item()

logits = deep_gemm.fp8_mqa_logits(
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/ops/triton_ops/__init__.py
@@ -15,6 +15,7 @@
"""

try:
from .indexer_update_attn_mask_offsets import update_indexer_attn_mask_offsets
from .pre_token_quant_fp8_kernel import _per_token_group_quant_fp8
from .qk_rmsnorm_fused_kernel import qk_rmsnorm_fused
from .repetition_early_stop_kernel import repetition_early_stopper_kernel
@@ -27,6 +28,7 @@
"repetition_early_stopper_kernel",
"qk_rmsnorm_fused",
"_per_token_group_quant_fp8",
"update_indexer_attn_mask_offsets",
]
except:
_TRITON_AVAILABLE = False
99 changes: 99 additions & 0 deletions fastdeploy/model_executor/ops/triton_ops/indexer_update_attn_mask_offsets.py
@@ -0,0 +1,99 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
import triton
import triton.language as tl

from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
enable_compat_on_triton_kernel,
)


@enable_compat_on_triton_kernel
Collaborator:
1. Clean up the unnecessary comments.
2. Write them in English.

Contributor Author:

ok

@triton.jit
def update_attn_mask_offsets_kernel(
seq_lens_this_time_ptr,
seq_lens_encoder_ptr,
cu_seqlens_k_ptr,
attn_mask_offsets_ptr,
BLOCK_M: tl.constexpr,
):
"""
seq_lens_this_time: [bsz]
seq_lens_encoder: [bsz]
cu_seqlens_k: [bsz+1]
attn_mask_offsets: [num_tokens * 2]
- even positions = start
- odd positions = end
"""
batch_id = tl.program_id(0)

seq_len_encoder = tl.load(seq_lens_encoder_ptr + batch_id)

# Tokens from decode requests (seq_lens_encoder == 0) do not take the prefill path in the Indexer;
# their attn_mask_offsets entries stay 0, so nothing needs to be written and we can return early.
if seq_len_encoder <= 0:
return

seq_len_this_time = tl.load(seq_lens_this_time_ptr + batch_id)
token_start_k = tl.load(cu_seqlens_k_ptr + batch_id) # start offset of this batch in the flattened token dimension

# Each program handles one batch and iterates over its tokens in chunks of BLOCK_M.
for block_start in range(0, seq_len_this_time, BLOCK_M):
offsets = block_start + tl.arange(0, BLOCK_M) # token offsets relative to the start of this batch
mask = offsets < seq_len_this_time

# Global token indices in the flattened token dimension
global_token_idx = token_start_k + offsets # [BLOCK_M]

# start: left edge of the causal window = start position of this batch's k sequence (same for all tokens)
ks = tl.full((BLOCK_M,), token_start_k, dtype=tl.int32)

# end: right edge of the causal window = global index of the current token + 1 (a token may only attend to itself and earlier tokens)
ke = (global_token_idx + 1).to(tl.int32)

# Write to attn_mask_offsets: even positions hold start, odd positions hold end.
tl.store(attn_mask_offsets_ptr + global_token_idx * 2, ks, mask=mask)
tl.store(attn_mask_offsets_ptr + global_token_idx * 2 + 1, ke, mask=mask)


def update_indexer_attn_mask_offsets(
ids_remove_padding,

🟡 Suggestion: the ids_remove_padding argument is only used to obtain num_tokens; its data is never passed to the kernel.

Inside the function it is used solely as num_tokens = ids_remove_padding.shape[0], while cu_seqlens_k[-1] likewise equals sum(seq_lens_this_time) and could be used instead, so there is no need to pass the whole tensor. The current interface can mislead callers into thinking the kernel reads the token ID data.

Suggested change:

num_tokens = int(cu_seqlens_k[-1].item())

and remove the ids_remove_padding parameter from the function signature, updating the call site in dsa_attention_backend.py accordingly.

seq_lens_this_time,
seq_lens_encoder,
cu_seqlens_k,
):
assert ids_remove_padding.ndim == 1
assert seq_lens_this_time.ndim == 1
assert seq_lens_encoder.ndim == 1
assert cu_seqlens_k.ndim == 1
num_tokens = ids_remove_padding.shape[0]

bsz = seq_lens_this_time.shape[0]
attention_mask_offset = paddle.zeros((num_tokens * 2), dtype=paddle.int32)

# One Triton program per batch; BLOCK_SIZE is the token-processing granularity within each program.
BLOCK_SIZE = 128

grid = (bsz,)

update_attn_mask_offsets_kernel[grid](
seq_lens_this_time,
seq_lens_encoder,
cu_seqlens_k,
attention_mask_offset,
BLOCK_M=BLOCK_SIZE,
)
return attention_mask_offset
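
For reference, here is a minimal pure-Python sketch of what update_indexer_attn_mask_offsets computes, mirroring the per-batch loop this PR removes from deepseek_v3.py; the helper name reference_attn_mask_offsets is hypothetical and intended only as a baseline when unit-testing the Triton kernel:

import paddle

def reference_attn_mask_offsets(seq_lens_this_time, seq_lens_encoder, cu_seqlens_k, num_tokens):
    # Interleaved layout: even positions hold ks (causal window start), odd positions hold ke (window end).
    out = [0] * (num_tokens * 2)
    for i in range(int(seq_lens_this_time.shape[0])):
        if int(seq_lens_encoder[i]) > 0:  # prefill requests only; decode entries stay 0
            start = int(cu_seqlens_k[i])
            end = int(cu_seqlens_k[i + 1])
            for t in range(start, end):
                out[t * 2] = start          # window left edge = start of this batch's k sequence
                out[t * 2 + 1] = t + 1      # window right edge = token attends to itself and earlier
    return paddle.to_tensor(out, dtype="int32")

Comparing this against the kernel output on a small mixed prefill/decode batch would also exercise the early-return path discussed in the allocation comment above.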
1 change: 0 additions & 1 deletion fastdeploy/worker/gpu_model_runner.py
@@ -1737,7 +1737,6 @@ def _initialize_attn_backend(self) -> None:
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
)

self.attn_backends.append(attn_backend)

def _dummy_pooler_run_task(