Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 186 additions & 29 deletions vllm/v1/attention/backends/mla/sparse_mla_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,128 @@ def _accumulate_indexed_attention_chunk_kernel(
tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask)


_PREFILL_INDEXED_HEAD_BLOCK = 8


@triton.jit
def _accumulate_indexed_attention_chunk_multihead_kernel(
q_ptr,
kv_flat_ptr,
indices_ptr,
lens_ptr,
max_score_ptr,
denom_ptr,
acc_ptr,
stride_q_t: tl.constexpr,
stride_q_h: tl.constexpr,
stride_q_d: tl.constexpr,
stride_kv_t,
stride_kv_d: tl.constexpr,
stride_indices_t: tl.constexpr,
stride_indices_c: tl.constexpr,
stride_state_t: tl.constexpr,
stride_state_h: tl.constexpr,
stride_acc_t: tl.constexpr,
stride_acc_h: tl.constexpr,
stride_acc_d: tl.constexpr,
num_heads: tl.constexpr,
head_dim: tl.constexpr,
num_candidates,
candidate_offset,
scale: tl.constexpr,
HEAD_BLOCK: tl.constexpr,
BLOCK_D: tl.constexpr,
):
token_idx = tl.program_id(0)
head_block_idx = tl.program_id(1)
head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)
dim_offsets = tl.arange(0, BLOCK_D)
head_mask = head_offsets < num_heads
dim_mask = dim_offsets < head_dim

q = tl.load(
q_ptr
+ token_idx * stride_q_t
+ head_offsets[:, None] * stride_q_h
+ dim_offsets[None, :] * stride_q_d,
mask=head_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(tl.float32)

state_base = token_idx * stride_state_t
running_max = tl.load(
max_score_ptr + state_base + head_offsets * stride_state_h,
mask=head_mask,
other=float("-inf"),
)
running_denom = tl.load(
denom_ptr + state_base + head_offsets * stride_state_h,
mask=head_mask,
other=0.0,
)
acc_base = token_idx * stride_acc_t
running_acc = tl.load(
acc_ptr
+ acc_base
+ head_offsets[:, None] * stride_acc_h
+ dim_offsets[None, :] * stride_acc_d,
mask=head_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(tl.float32)

valid_len = tl.load(lens_ptr + token_idx)
local_eff = tl.minimum(
num_candidates,
tl.maximum(valid_len - candidate_offset, 0),
)

for candidate_idx in range(0, local_eff):
kv_index = tl.load(
indices_ptr
+ token_idx * stride_indices_t
+ candidate_idx * stride_indices_c
)
is_valid = kv_index >= 0

if is_valid:
kv = tl.load(
kv_flat_ptr
+ kv_index.to(tl.int64) * stride_kv_t
+ dim_offsets * stride_kv_d,
mask=dim_mask,
other=0.0,
).to(tl.float32)
scores = tl.sum(q * kv[None, :], axis=1) * scale
next_max = tl.maximum(running_max, scores)
previous_weight = tl.exp(running_max - next_max)
candidate_weight = tl.exp(scores - next_max)
running_acc = (
running_acc * previous_weight[:, None]
+ kv[None, :] * candidate_weight[:, None]
)
running_denom = running_denom * previous_weight + candidate_weight
running_max = next_max

tl.store(
max_score_ptr + state_base + head_offsets * stride_state_h,
running_max,
mask=head_mask,
)
tl.store(
denom_ptr + state_base + head_offsets * stride_state_h,
running_denom,
mask=head_mask,
)
tl.store(
acc_ptr
+ acc_base
+ head_offsets[:, None] * stride_acc_h
+ dim_offsets[None, :] * stride_acc_d,
running_acc,
mask=head_mask[:, None] & dim_mask[None, :],
)


def accumulate_indexed_sparse_mla_attention_chunk(
q: torch.Tensor,
kv_flat: torch.Tensor,
Expand Down Expand Up @@ -1379,35 +1501,70 @@ def accumulate_indexed_sparse_mla_attention_chunk(
num_heads = max_score.shape[1]
num_candidates = indices.shape[1]
block_d = min(1024, triton.next_power_of_2(head_dim))
grid = (num_tokens, num_heads)
_accumulate_indexed_attention_chunk_kernel[grid](
q,
kv_flat,
indices,
lens,
max_score,
denom,
acc,
q.stride(0),
q.stride(1),
q.stride(2),
kv_flat.stride(0),
kv_flat.stride(1),
indices.stride(0),
indices.stride(1),
max_score.stride(0),
max_score.stride(1),
acc.stride(0),
acc.stride(1),
acc.stride(2),
num_heads,
head_dim,
num_candidates,
candidate_offset,
scale,
BLOCK_D=block_d,
# num_warps / num_stages supplied by @triton.autotune above.
)
head_block = _PREFILL_INDEXED_HEAD_BLOCK

if num_heads >= head_block:
grid = (num_tokens, triton.cdiv(num_heads, head_block))
_accumulate_indexed_attention_chunk_multihead_kernel[grid](
q,
kv_flat,
indices,
lens,
max_score,
denom,
acc,
q.stride(0),
q.stride(1),
q.stride(2),
kv_flat.stride(0),
kv_flat.stride(1),
indices.stride(0),
indices.stride(1),
max_score.stride(0),
max_score.stride(1),
acc.stride(0),
acc.stride(1),
acc.stride(2),
num_heads,
head_dim,
num_candidates,
candidate_offset,
scale,
HEAD_BLOCK=head_block,
BLOCK_D=block_d,
num_warps=4,
num_stages=2,
)
else:
grid = (num_tokens, num_heads)
_accumulate_indexed_attention_chunk_kernel[grid](
q,
kv_flat,
indices,
lens,
max_score,
denom,
acc,
q.stride(0),
q.stride(1),
q.stride(2),
kv_flat.stride(0),
kv_flat.stride(1),
indices.stride(0),
indices.stride(1),
max_score.stride(0),
max_score.stride(1),
acc.stride(0),
acc.stride(1),
acc.stride(2),
num_heads,
head_dim,
num_candidates,
candidate_offset,
scale,
BLOCK_D=block_d,
)



@triton.autotune(
Expand Down
8 changes: 0 additions & 8 deletions vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,6 @@ def deepseek_v4_sm12x_fp8_einsum(
BLOCK_TOKENS=block_tokens,
BLOCK_OUT=block_out,
BLOCK_HIDDEN=block_hidden,
# Pinned to the SM12x-optimal config: a previous ``@triton.autotune``
# block selected from {num_warps in {4,8}, num_stages in {2,3}} with
# key=["num_tokens", ...]. ``num_tokens`` varies per request, so the
# autotune cache missed every call and the 4-config bench replayed
# on every shape — pure overhead. The other three keys are
# model-architecture-fixed, so the same config (num_warps=4,
# num_stages=3) always won; we pin it directly. Reported by
# ``alexbi29`` in PR #41834 comment 4464750956.
num_warps=4,
num_stages=3,
)
Expand Down