@@ -422,21 +422,22 @@ def _sdpa_fwd_kernel_body(
422422
423423 offs_n_init = tl .arange (0 , BLOCK_N )
424424
425+ # Mask-aware block skip (the loop still visits every block; masked ones do no K/V work). A KV block that is fully masked (sliding-window
426+ # or causal) contributes nothing to the online softmax — every entry is
427+ # -inf, so p=0 and m_i/l_i/acc are left unchanged. We detect such blocks up
428+ # front and skip their K/V loads and both matmuls. This is exact: it only
429+ # skips work the mask would have zeroed out anyway. At seq=2048 the 50
430+ # sliding-window(1024) layers and the 10 causal layers each leave roughly
431+ # half (or more) of their KV blocks fully masked, so this is a large cut to
432+ # the dominant prefill cost. The skip condition is a CTA-wide reduction, so
433+ # the branch is uniform and turns into a real skip (not predication).
434+ if IS_CAUSAL :
435+ max_seq_pos = tl .max (seq_pos )
436+
425437 for start_n in tl .range (0 , Lk , BLOCK_N ):
426438 offs_n = start_n + offs_n_init
427439
428- # K load: uniform (single KV head, shared across all Q heads in tile)
429- k_ptrs = K_ptr + (
430- b * stride_kb
431- + h_kv * stride_kh
432- + (offs_n [:, None ] * stride_kn )
433- + (offs_d [None , :] * stride_kd )
434- )
435- k_mask = (offs_n [:, None ] < Lk ) & (offs_d [None , :] < HEAD_DIM )
436- k = tl .load (k_ptrs , mask = k_mask , other = 0.0 ).to (tl .bfloat16 )
437-
438- qk = (tl .dot (q , tl .trans (k )).to (tl .float32 ) * sm_scale ).to (tl .float32 )
439-
440+ # Decide whether any row in this tile actually attends to this KV block.
440441 if HAS_MASK :
441442 mask_ptrs = Mask_ptr + (
442443 b * stride_mb
@@ -445,39 +446,60 @@ def _sdpa_fwd_kernel_body(
445446 )
446447 mn_mask = row_valid [:, None ] & (offs_n [None , :] < Lk )
447448 mask_block = tl .load (mask_ptrs , mask = mn_mask , other = False )
448- qk = tl .where (
449- mask_block , qk , tl .full (qk .shape , - float ("inf" ), dtype = tl .float32 )
449+ block_active = tl .sum (mask_block .to (tl .int32 )) > 0
450+ elif IS_CAUSAL :
451+ # Active unless the block's first key lies past every row's position, i.e. the block is entirely in the future for every row.
452+ block_active = start_n <= max_seq_pos
453+ else :
454+ block_active = True
455+
456+ if block_active :
457+ # K load: uniform (single KV head, shared across Q heads in tile)
458+ k_ptrs = K_ptr + (
459+ b * stride_kb
460+ + h_kv * stride_kh
461+ + (offs_n [:, None ] * stride_kn )
462+ + (offs_d [None , :] * stride_kd )
450463 )
464+ k_mask = (offs_n [:, None ] < Lk ) & (offs_d [None , :] < HEAD_DIM )
465+ k = tl .load (k_ptrs , mask = k_mask , other = 0.0 ).to (tl .bfloat16 )
451466
452- if IS_CAUSAL :
453- causal = offs_n [None , :] > seq_pos [:, None ]
454- qk = tl .where (
455- causal , tl .full (qk .shape , - float ("inf" ), dtype = tl .float32 ), qk
456- )
467+ qk = (tl .dot (q , tl .trans (k )).to (tl .float32 ) * sm_scale ).to (tl .float32 )
457468
458- m_ij = tl .maximum (m_i , tl .max (qk , axis = 1 ).to (tl .float32 ))
459- safe_diff = tl .where (
460- m_ij [:, None ] > - float ("inf" ), qk - m_ij [:, None ], - float ("inf" )
461- )
462- p_f32 = tl .exp (safe_diff ).to (tl .float32 )
463- l_ij = tl .sum (p_f32 , axis = 1 ).to (tl .float32 )
464- safe_alpha_diff = tl .where (m_ij > - float ("inf" ), m_i - m_ij , 0.0 )
465- alpha = tl .exp (safe_alpha_diff ).to (tl .float32 )
469+ if HAS_MASK :
470+ qk = tl .where (
471+ mask_block , qk , tl .full (qk .shape , - float ("inf" ), dtype = tl .float32 )
472+ )
466473
467- # V load: uniform (single KV head)
468- v_ptrs = V_ptr + (
469- b * stride_vb
470- + h_kv * stride_vh
471- + (offs_n [:, None ] * stride_vn )
472- + (offs_d [None , :] * stride_vd )
473- )
474- v_mask = (offs_n [:, None ] < Lk ) & (offs_d [None , :] < HEAD_DIM )
475- v = tl .load (v_ptrs , mask = v_mask , other = 0.0 ).to (tl .bfloat16 )
474+ if IS_CAUSAL :
475+ causal = offs_n [None , :] > seq_pos [:, None ]
476+ qk = tl .where (
477+ causal , tl .full (qk .shape , - float ("inf" ), dtype = tl .float32 ), qk
478+ )
476479
477- p_bf16 = p_f32 .to (tl .bfloat16 )
478- acc = (acc * alpha [:, None ] + tl .dot (p_bf16 , v )).to (tl .float32 )
479- l_i = (l_i * alpha + l_ij ).to (tl .float32 )
480- m_i = m_ij
480+ m_ij = tl .maximum (m_i , tl .max (qk , axis = 1 ).to (tl .float32 ))
481+ safe_diff = tl .where (
482+ m_ij [:, None ] > - float ("inf" ), qk - m_ij [:, None ], - float ("inf" )
483+ )
484+ p_f32 = tl .exp (safe_diff ).to (tl .float32 )
485+ l_ij = tl .sum (p_f32 , axis = 1 ).to (tl .float32 )
486+ safe_alpha_diff = tl .where (m_ij > - float ("inf" ), m_i - m_ij , 0.0 )
487+ alpha = tl .exp (safe_alpha_diff ).to (tl .float32 )
488+
489+ # V load: uniform (single KV head)
490+ v_ptrs = V_ptr + (
491+ b * stride_vb
492+ + h_kv * stride_vh
493+ + (offs_n [:, None ] * stride_vn )
494+ + (offs_d [None , :] * stride_vd )
495+ )
496+ v_mask = (offs_n [:, None ] < Lk ) & (offs_d [None , :] < HEAD_DIM )
497+ v = tl .load (v_ptrs , mask = v_mask , other = 0.0 ).to (tl .bfloat16 )
498+
499+ p_bf16 = p_f32 .to (tl .bfloat16 )
500+ acc = (acc * alpha [:, None ] + tl .dot (p_bf16 , v )).to (tl .float32 )
501+ l_i = (l_i * alpha + l_ij ).to (tl .float32 )
502+ m_i = m_ij
481503
482504 inv_l_i = tl .where (l_i > 0 , 1.0 / l_i , 0.0 )
483505 acc = acc * inv_l_i [:, None ]
0 commit comments