Skip to content

[ilu/ttx] opt swa_paged_prefill: remove boundary_check in nomask path…#373

Open
sky-fun wants to merge 1 commit into
masterfrom
ilu/ttx_opt_swa_prefill_tmp
Open

[ilu/ttx] opt swa_paged_prefill: remove boundary_check in nomask path…#373
sky-fun wants to merge 1 commit into
masterfrom
ilu/ttx_opt_swa_prefill_tmp

Conversation

@sky-fun

@sky-fun sky-fun commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

opt swa_paged_prefill:

  • remove boundary_check in nomask path
  • remove tl.trans when BLOCK_D == HEAD_DIM
    speed up 40%

@sky-fun sky-fun requested a review from madengfei June 22, 2026 06:54
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Comment -- Optimizes the no-mask SWA inner loop by pre-transposing K and dropping boundary checks when BLOCK_D == HEAD_DIM, but the K/V boundary-check removal looks unsafe near the tail of kv_seq_len.

Summary

The PR rewrites the no-mask path of the SWA forward kernel to feed tl.dot a pre-transposed K (avoiding tl.trans and enabling MMA-style loads) and skips boundary_check on K/V loads when BLOCK_D == HEAD_DIM. The block_ptr construction in _swa_infer_kernel and _swa_paged_prefill_kernel is updated to match.

Must fix

  • [BLOCKER] Unchecked K/V load may read past kv_seq_len in infer path -- mojo_opset/backends/ttx/kernels/ilu/swa.py:251-260 -- The comment claims the nomask region "guarantees full in-bounds tiles", but in _swa_infer_kernel the block_ptr shape is (HEAD_DIM, kv_seq_len) / (kv_seq_len, HEAD_DIM). If kv_seq_len is not a multiple of BLOCK_N, dropping boundary_check on the last full nomask tile reads OOB. Verify full_local_end_block cannot reach a partial tail tile, or keep boundary_check.

Suggestions

Suggestions (2)
  • [MAJOR] Verify nomask tile alignment invariant -- mojo_opset/backends/ttx/kernels/ilu/swa.py:251-280 -- The correctness of removing boundary_check rests on (a) PAGE_SIZE % BLOCK_N == 0 (asserted) and (b) every nomask iteration landing on a fully-populated BLOCK_N tile. Add a tl.static_assert or runtime check / comment proof for the infer (non-paged) path where kv_seq_len need not be a multiple of BLOCK_N.
  • [MINOR] Duplication of block_ptr construction -- mojo_opset/backends/ttx/kernels/ilu/swa.py:498-520, 897-925 -- The if BLOCK_D == HEAD_DIM branches duplicate ~10 lines each in two kernels; consider a small tl.constexpr helper or unifying via swapped shape/strides/order tuples to reduce drift risk.

Nits

Nits (1)
  • [NIT] Non-ASCII bullet character in kernel-side comments -- mojo_opset/backends/ttx/kernels/ilu/swa.py:251-260 -- repo appears to use plain ASCII elsewhere; use - for consistency.

Notes

  • [CHECK] _swa_acc_fwd_nomask_mxn is also called from the paged-prefill kernel where kv_block_len may be < BLOCK_N on the last page; confirm the caller never invokes the nomask helper for a partial tile, otherwise the unchecked load regresses correctness there too.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes the Triton-based sliding window attention (SWA) kernels by avoiding transpose operations and dropping boundary checks when the head dimension is a power of two (BLOCK_D == HEAD_DIM). This is achieved by constructing the K block pointer in a transposed layout directly at creation time. Additionally, window presence flags are marked as compile-time constants. The review feedback suggests a further optimization in the paged prefill kernel: replacing the dynamic 'kv_block_len' with the compile-time constant 'BLOCK_N' when creating the K block pointer in the 'nomask' loop, which allows the Triton compiler to perform better static analysis and instruction scheduling.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +904 to +927
if BLOCK_D == HEAD_DIM:
k_block_ptr = tl.make_block_ptr(
base=k_cache_ptr
+ physical_page_id * stride_kp
+ kv_head_id * stride_kh
+ kv_block_start_in_page * stride_kt,
shape=(HEAD_DIM, kv_block_len),
strides=(stride_kd, stride_kt),
offsets=(0, 0),
block_shape=(BLOCK_D, BLOCK_N),
order=(0, 1),
)
else:
k_block_ptr = tl.make_block_ptr(
base=k_cache_ptr
+ physical_page_id * stride_kp
+ kv_head_id * stride_kh
+ kv_block_start_in_page * stride_kt,
shape=(kv_block_len, HEAD_DIM),
strides=(stride_kt, stride_kd),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_D),
order=(1, 0),
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the nomask loop of the paged prefill kernel, the block is mathematically guaranteed to be fully in-bounds, meaning kv_block_len is always exactly equal to BLOCK_N.

By replacing the dynamic kv_block_len with the compile-time constant BLOCK_N in the shape parameter of tl.make_block_ptr, we make the block pointer shape fully static. This allows the Triton compiler to perform better static analysis, address generation, and instruction scheduling optimizations.

                    if BLOCK_D == HEAD_DIM:
                        k_block_ptr = tl.make_block_ptr(
                            base=k_cache_ptr
                            + physical_page_id * stride_kp
                            + kv_head_id * stride_kh
                            + kv_block_start_in_page * stride_kt,
                            shape=(HEAD_DIM, BLOCK_N),
                            strides=(stride_kd, stride_kt),
                            offsets=(0, 0),
                            block_shape=(BLOCK_D, BLOCK_N),
                            order=(0, 1),
                        )
                    else:
                        k_block_ptr = tl.make_block_ptr(
                            base=k_cache_ptr
                            + physical_page_id * stride_kp
                            + kv_head_id * stride_kh
                            + kv_block_start_in_page * stride_kt,
                            shape=(BLOCK_N, HEAD_DIM),
                            strides=(stride_kt, stride_kd),
                            offsets=(0, 0),
                            block_shape=(BLOCK_N, BLOCK_D),
                            order=(1, 0),
                        )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant