[ilu/ttx] opt swa: decode add fast path when windows cover the sequence; quant prefill use dequant + bf16 swa paged prefill #354
sky-fun
commented
Jun 11, 2026
- SWA paged decode: add a fast path for the case where the global + local windows fully cover the sequence; ~10x faster at BS=8, block=32 (0.334 ms -> 0.0326 ms).
- SWA paged prefill with KV dequant: dequantize KV and use tl.dot for the QK and PV matmuls. Also adds a no-mask fast path for fully covered local-window blocks where causal masking is unnecessary; ~45x faster (116 ms -> 2.6 ms).
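For reviewers, a minimal sketch of the coverage condition that would gate the decode fast path. This is not the kernel code: the names `seq_len`, `global_window` and `local_window` are hypothetical, and it assumes the global window means the first `global_window` keys and the local window means the last `local_window` keys seen by the decode query.

```python
def windows_cover_sequence(seq_len, global_window, local_window):
    """True when the global prefix plus the local suffix include every key.

    Assumed semantics (not taken from the PR): the decode query sits at
    position seq_len - 1, sees keys [0, global_window) and the last
    local_window keys.
    """
    return global_window + local_window >= seq_len


def visible_keys(seq_len, global_window, local_window):
    """Reference per-key visibility mask for the same decode query."""
    local_start = max(seq_len - local_window, 0)
    return [k < global_window or k >= local_start for k in range(seq_len)]


def needs_window_mask(seq_len, global_window, local_window):
    """The dense (unmasked) loop is only valid when nothing is hidden."""
    return not windows_cover_sequence(seq_len, global_window, local_window)
```

When the check passes, every KV block can be visited in one dense loop with no per-block window test, which is presumably where the decode speedup comes from.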
…ows fully cover the sequence
Code Review
This pull request refactors the sliding window attention (SWA) paged prefill and decode kernels in Triton. It introduces block-based processing helper functions to replace token-by-token loops, adds support for pipelining in decode kernels, and decouples the KV tile width from the page size. A high-severity issue was identified in _swa_acc_fwd_dequant_mxn where entirely masked rows can cause NaN propagation due to subtracting -inf from -inf. A code suggestion was provided to safely handle masked rows.
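To illustrate the masked-row issue, here is a plain-Python sketch of an online-softmax update for one query row. It is not the code of `_swa_acc_fwd_dequant_mxn` and not the suggestion attached to the review; the function name and the choice to return 0 for a fully masked row are assumptions made for illustration.

```python
import math

NEG_INF = float("-inf")


def online_softmax_row(score_tiles, value_tiles):
    """Online softmax over tiles for one query row with scalar values.

    Masked positions carry a score of -inf. Tiles are assumed non-empty.
    """
    m = NEG_INF  # running row max
    l = 0.0      # running softmax denominator
    acc = 0.0    # running weighted sum of values
    for scores, values in zip(score_tiles, value_tiles):
        m_new = max(m, max(scores))
        # Guard: if every score seen so far is masked, m_new is -inf and
        # (-inf) - (-inf) would evaluate to NaN. Subtract 0 instead.
        m_safe = 0.0 if m_new == NEG_INF else m_new
        alpha = math.exp(m - m_safe)  # rescale factor for the old state
        p = [math.exp(s - m_safe) for s in scores]
        l = l * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, values))
        m = m_new
    # A fully masked row has l == 0; return 0 rather than dividing by zero.
    return acc / l if l > 0.0 else 0.0
```

Without the `m_safe` substitution, a tile whose row is entirely masked yields `exp(nan)`, and the NaN then persists through every later rescale of `l` and `acc`.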
Claude Code Review
Verdict: Request changes -- The rewritten paged-prefill SWA kernel uses log2-domain scaling inconsistently with the per-row int8 quant path and reads K/V scales with the wrong head index, which will produce numerically wrong outputs.
Summary
This PR rewrites the paged-prefill SWA-with-KV-dequant kernel from a per-token scalar loop into a tiled (BLOCK_M x BLOCK_N) online-softmax kernel with split global/local/causal-mask regions, plus adds a
Must fix
Suggestions (5)
Nits (3)
Notes
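On the log2-domain scaling point, a small sketch of what consistent scaling looks like when `exp` is replaced by `exp2`. This is a plain-Python illustration, not the kernel: `softmax_exp` and `softmax_exp2` are hypothetical names, and the only claim is the identity exp(x) = 2^(x * log2(e)), so the base-change factor has to be folded into the scale exactly once.

```python
import math

LOG2E = math.log2(math.e)  # base-change factor from exp to exp2


def softmax_exp(scores, sm_scale):
    """Reference softmax in the natural-exponent domain."""
    scaled = [s * sm_scale for s in scores]
    m = max(scaled)
    p = [math.exp(x - m) for x in scaled]
    total = sum(p)
    return [x / total for x in p]


def softmax_exp2(scores, sm_scale, fold_log2e=True):
    """Softmax in the log2 domain.

    With fold_log2e=True the scale carries LOG2E once and the result
    matches softmax_exp. With fold_log2e=False the factor is missing,
    which models an inconsistent code path.
    """
    qk_scale = sm_scale * LOG2E if fold_log2e else sm_scale
    scaled = [s * qk_scale for s in scores]
    m = max(scaled)
    p = [2.0 ** (x - m) for x in scaled]
    total = sum(p)
    return [x / total for x in p]
```

If one path multiplies by `LOG2E` and another does not (or a path applies it twice), the softmax is effectively computed at a different temperature, so outputs differ without any NaN or crash to flag it.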
4a74d88 to 17b67b5
Claude Code Review
Verdict: Request changes -- The rewritten paged-prefill SWA kernel has a likely correctness bug in K/V scale indexing under GQA and a suspicious sm_scale doubling, and the new decode quant fast-path duplicates KV blocks.
Summary
This PR rewrites the int8 paged-prefill SWA kernel from a per-token scalar implementation to a tiled (BLOCK_M x BLOCK_N) implementation with shared masked/no-mask helpers, and adds a "fast-path" dense loop in three decode kernels for the case where the global+local windows cover the whole KV sequence. It also adds a pipeline-stages heuristic for decode kernels and changes how K/V per-channel scales are sliced under GQA.
Must fix
Suggestions (5)
Nits (3)
Notes
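For the K/V scale indexing concern under GQA, a minimal sketch of the usual head mapping. The layout is an assumption, not taken from the PR: scales are taken to be stored per KV head as `scales[kv_head][channel]`, and consecutive query heads are taken to share one KV head. All function names are hypothetical.

```python
def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    """Map a query head to the KV head it shares under GQA.

    Assumes consecutive query heads form one group per KV head.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def dequant_channels(int8_row, scales, kv_head):
    """Dequantize one int8 K or V row with per-channel scales.

    The scales are indexed by the KV head, never by the query head.
    """
    return [q * s for q, s in zip(int8_row, scales[kv_head])]


def dequant_for_q_head(int8_row, scales, q_head, num_q_heads, num_kv_heads):
    """Pick the scale row via the GQA mapping, then dequantize."""
    kv_head = kv_head_for_q_head(q_head, num_q_heads, num_kv_heads)
    return dequant_channels(int8_row, scales, kv_head)
```

Indexing the scale table with the query head either reads the wrong row or runs past the end once `num_q_heads > num_kv_heads`, which would be consistent with the numerically wrong outputs the review describes.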
…f16 swa paged prefill
[ilu/ttx] merge dequant and swa_infer
[ilu/ttx] swa paged prefill quant use exp2
[ilu/ttx] change block_n for occupancy, remove redundant type convert
17b67b5 to c79c995
Claude Code Review
Verdict: Request changes -- The new paged-prefill SWA kernel has a likely-incorrect KV-scale indexing change and a dead
Summary
Rewrites the paged-prefill SWA-with-int8-KV kernel from a per-token scalar loop to a tiled BLOCK_M x BLOCK_N online-softmax kernel with autotuning, and adds
Must fix
Suggestions (4)
Nits (3)
Notes