Skip to content

Commit ac2968b

Browse files
author
prefill-dev2
committed
[cuda][prefill] window-aware SDPA: skip fully-masked KV blocks (idea #1)
Block-sparse early-exit in _sdpa_fwd_kernel_body: skip KV blocks that are entirely masked (sliding-window via HAS_MASK sum==0, causal via start_n>max_seq_pos). Exact (skipped blocks are x1,+0 no-ops). Prefill +46-88% all lengths; decode safe; SDPA nsys 58.1%->18.5%. Numerically bf16-exact vs dense+mask (unit test).
1 parent 457a316 commit ac2968b

1 file changed

Lines changed: 62 additions & 40 deletions

File tree

  • backends/cuda/triton/kernels

backends/cuda/triton/kernels/sdpa.py

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -422,21 +422,22 @@ def _sdpa_fwd_kernel_body(
422422

423423
offs_n_init = tl.arange(0, BLOCK_N)
424424

425+
# Window-aware early-exit. A KV block that is fully masked (sliding-window
426+
# or causal) contributes nothing to the online softmax — every entry is
427+
# -inf, so p=0 and m_i/l_i/acc are left unchanged. We detect such blocks up
428+
# front and skip their K/V loads and both matmuls. This is exact: it only
429+
# skips work the mask would have zeroed out anyway. At seq=2048 the 50
430+
# sliding-window(1024) layers and the 10 causal layers each leave roughly
431+
# half (or more) of their KV blocks fully masked, so this is a large cut to
432+
# the dominant prefill cost. The skip condition is a CTA-wide reduction, so
433+
# the branch is uniform and turns into a real skip (not predication).
434+
if IS_CAUSAL:
435+
max_seq_pos = tl.max(seq_pos)
436+
425437
for start_n in tl.range(0, Lk, BLOCK_N):
426438
offs_n = start_n + offs_n_init
427439

428-
# K load: uniform (single KV head, shared across all Q heads in tile)
429-
k_ptrs = K_ptr + (
430-
b * stride_kb
431-
+ h_kv * stride_kh
432-
+ (offs_n[:, None] * stride_kn)
433-
+ (offs_d[None, :] * stride_kd)
434-
)
435-
k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
436-
k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16)
437-
438-
qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)
439-
440+
# Decide whether any row in this tile actually attends to this KV block.
440441
if HAS_MASK:
441442
mask_ptrs = Mask_ptr + (
442443
b * stride_mb
@@ -445,39 +446,60 @@ def _sdpa_fwd_kernel_body(
445446
)
446447
mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk)
447448
mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False)
448-
qk = tl.where(
449-
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
449+
block_active = tl.sum(mask_block.to(tl.int32)) > 0
450+
elif IS_CAUSAL:
451+
# Block is entirely in the future for every row -> skip.
452+
block_active = start_n <= max_seq_pos
453+
else:
454+
block_active = True
455+
456+
if block_active:
457+
# K load: uniform (single KV head, shared across Q heads in tile)
458+
k_ptrs = K_ptr + (
459+
b * stride_kb
460+
+ h_kv * stride_kh
461+
+ (offs_n[:, None] * stride_kn)
462+
+ (offs_d[None, :] * stride_kd)
450463
)
464+
k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
465+
k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16)
451466

452-
if IS_CAUSAL:
453-
causal = offs_n[None, :] > seq_pos[:, None]
454-
qk = tl.where(
455-
causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk
456-
)
467+
qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)
457468

458-
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
459-
safe_diff = tl.where(
460-
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
461-
)
462-
p_f32 = tl.exp(safe_diff).to(tl.float32)
463-
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
464-
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
465-
alpha = tl.exp(safe_alpha_diff).to(tl.float32)
469+
if HAS_MASK:
470+
qk = tl.where(
471+
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
472+
)
466473

467-
# V load: uniform (single KV head)
468-
v_ptrs = V_ptr + (
469-
b * stride_vb
470-
+ h_kv * stride_vh
471-
+ (offs_n[:, None] * stride_vn)
472-
+ (offs_d[None, :] * stride_vd)
473-
)
474-
v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
475-
v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16)
474+
if IS_CAUSAL:
475+
causal = offs_n[None, :] > seq_pos[:, None]
476+
qk = tl.where(
477+
causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk
478+
)
476479

477-
p_bf16 = p_f32.to(tl.bfloat16)
478-
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
479-
l_i = (l_i * alpha + l_ij).to(tl.float32)
480-
m_i = m_ij
480+
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
481+
safe_diff = tl.where(
482+
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
483+
)
484+
p_f32 = tl.exp(safe_diff).to(tl.float32)
485+
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
486+
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
487+
alpha = tl.exp(safe_alpha_diff).to(tl.float32)
488+
489+
# V load: uniform (single KV head)
490+
v_ptrs = V_ptr + (
491+
b * stride_vb
492+
+ h_kv * stride_vh
493+
+ (offs_n[:, None] * stride_vn)
494+
+ (offs_d[None, :] * stride_vd)
495+
)
496+
v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
497+
v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16)
498+
499+
p_bf16 = p_f32.to(tl.bfloat16)
500+
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
501+
l_i = (l_i * alpha + l_ij).to(tl.float32)
502+
m_i = m_ij
481503

482504
inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0)
483505
acc = acc * inv_l_i[:, None]

0 commit comments

Comments
 (0)