[ilu/ttx] opt swa_paged_prefill: remove boundary_check in nomask path…#373
sky-fun wants to merge 1 commit into
…,remove tl.trans when BLOCK_D == HEAD_DIM
Claude Code Review
Verdict: Comment -- Optimizes the no-mask SWA inner loop by pre-transposing K and dropping boundary checks when BLOCK_D == HEAD_DIM, but the K/V boundary-check removal looks unsafe near the tail of kv_seq_len.
Summary
The PR rewrites the no-mask path of the SWA forward kernel to feed
Must fix
Suggestions (2)
Nits (1)
Notes
Code Review
This pull request optimizes the Triton-based sliding window attention (SWA) kernels by avoiding transpose operations and dropping boundary checks when the head dimension is a power of two (BLOCK_D == HEAD_DIM). This is achieved by constructing the K block pointer in a transposed layout directly at creation time. Additionally, window presence flags are marked as compile-time constants. The review feedback suggests a further optimization in the paged prefill kernel: replacing the dynamic 'kv_block_len' with the compile-time constant 'BLOCK_N' when creating the K block pointer in the 'nomask' loop, which allows the Triton compiler to perform better static analysis and instruction scheduling.
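To illustrate why building the K block pointer in a transposed layout removes the need for a separate transpose, here is a minimal plain-Python sketch (not the PR's Triton code). The helper `load_tile`, the flat buffer, and the tile sizes are hypothetical; the point is only that swapping `shape` and `strides` over the same memory yields the transposed tile directly.

```python
def load_tile(buf, base, shape, strides):
    """Gather a 2D tile from a flat buffer using element strides (hypothetical helper)."""
    rows, cols = shape
    s0, s1 = strides
    return [[buf[base + i * s0 + j * s1] for j in range(cols)] for i in range(rows)]


# Hypothetical tile sizes, chosen small for illustration.
BLOCK_N, HEAD_DIM = 3, 4
# Token-major storage: one token occupies HEAD_DIM contiguous elements.
stride_kt, stride_kd = HEAD_DIM, 1
buf = list(range(BLOCK_N * HEAD_DIM))

# Token-major view, as in the else branch: shape=(BLOCK_N, HEAD_DIM).
k_tile = load_tile(buf, 0, (BLOCK_N, HEAD_DIM), (stride_kt, stride_kd))

# Pre-transposed view of the same memory: shape and strides are swapped,
# mirroring shape=(HEAD_DIM, BLOCK_N), strides=(stride_kd, stride_kt).
k_tile_t = load_tile(buf, 0, (HEAD_DIM, BLOCK_N), (stride_kd, stride_kt))

# Reference: explicit transpose of the token-major tile.
explicit_t = [list(col) for col in zip(*k_tile)]
```

Both `k_tile_t` and `explicit_t` hold the same values, so a matrix product against the pre-transposed view needs no extra transpose step.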
if BLOCK_D == HEAD_DIM:
    k_block_ptr = tl.make_block_ptr(
        base=k_cache_ptr
        + physical_page_id * stride_kp
        + kv_head_id * stride_kh
        + kv_block_start_in_page * stride_kt,
        shape=(HEAD_DIM, kv_block_len),
        strides=(stride_kd, stride_kt),
        offsets=(0, 0),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1),
    )
else:
    k_block_ptr = tl.make_block_ptr(
        base=k_cache_ptr
        + physical_page_id * stride_kp
        + kv_head_id * stride_kh
        + kv_block_start_in_page * stride_kt,
        shape=(kv_block_len, HEAD_DIM),
        strides=(stride_kt, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0),
    )
In the nomask loop of the paged prefill kernel, the block is mathematically guaranteed to be fully in-bounds, meaning kv_block_len is always exactly equal to BLOCK_N.
Replacing the dynamic kv_block_len with the compile-time constant BLOCK_N in the shape argument of tl.make_block_ptr makes the block pointer shape fully static. This may let the Triton compiler do better static analysis, address generation, and instruction scheduling.
if BLOCK_D == HEAD_DIM:
    k_block_ptr = tl.make_block_ptr(
        base=k_cache_ptr
        + physical_page_id * stride_kp
        + kv_head_id * stride_kh
        + kv_block_start_in_page * stride_kt,
        shape=(HEAD_DIM, BLOCK_N),
        strides=(stride_kd, stride_kt),
        offsets=(0, 0),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1),
    )
else:
    k_block_ptr = tl.make_block_ptr(
        base=k_cache_ptr
        + physical_page_id * stride_kp
        + kv_head_id * stride_kh
        + kv_block_start_in_page * stride_kt,
        shape=(BLOCK_N, HEAD_DIM),
        strides=(stride_kt, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0),
    )
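The static shape above is only safe if every block visited by the nomask loop is full. As a rough plain-Python sketch of that precondition (this is not the kernel's actual loop; `split_blocks` and its arguments are hypothetical, and paging is ignored), a KV range can be split into full blocks, where kv_block_len == BLOCK_N holds, and an optional shorter tail that would still need a boundary check:

```python
def split_blocks(kv_start, kv_end, block_n):
    """Split the half-open range [kv_start, kv_end) into full blocks and a tail.

    Returns (full_block_starts, tail), where tail is (start, length) for a
    final partial block, or None when the range is a multiple of block_n.
    Hypothetical helper for illustration only.
    """
    n_full = (kv_end - kv_start) // block_n
    full_block_starts = [kv_start + i * block_n for i in range(n_full)]
    tail_start = kv_start + n_full * block_n
    tail = (tail_start, kv_end - tail_start) if tail_start < kv_end else None
    return full_block_starts, tail


# A range of 10 tokens with block_n=4: two full blocks, then a 2-token tail.
full, tail = split_blocks(0, 10, 4)
```

If the nomask loop only ever iterates over the `full` starts, using BLOCK_N in the shape is equivalent to using kv_block_len; if it can reach the tail, the boundary check is still needed there.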
opt swa_paged_prefill:
speed up 40%