Skip to content

[ilu/ttx] opt swa: decode add fast path when windows cover the sequence;quant prefill use dequant + bf16 swa paged prefill#354

Open
sky-fun wants to merge 2 commits into
masterfrom
ilu/ttx_opt_swa_quant_tmp
Open

[ilu/ttx] opt swa: decode add fast path when windows cover the sequence;quant prefill use dequant + bf16 swa paged prefill#354
sky-fun wants to merge 2 commits into
masterfrom
ilu/ttx_opt_swa_quant_tmp

Conversation

@sky-fun

@sky-fun sky-fun commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator
  • SWA paged decode: add fast path when Global + Local windows fully cover the sequence; ~10 faster when BS=8,block=32(0.334ms -> 0.0326ms);
  • SWA paged prefill with kv dequant: use dequant and tl.dot for QK and PV matmuls. adds a nomask fast path for fully covered local-window blocks when causal masking is unnecessary.~45x faster (116ms -> 2.6ms).

@sky-fun sky-fun requested a review from madengfei June 11, 2026 06:13

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the sliding window attention (SWA) paged prefill and decode kernels in Triton. It introduces block-based processing helper functions to replace token-by-token loops, adds support for pipelining in decode kernels, and decouples the KV tile width from the page size. A high-severity issue was identified in _swa_acc_fwd_dequant_mxn where entirely masked rows can cause NaN propagation due to subtracting -inf from -inf. A code suggestion was provided to safely handle masked rows.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/backends/ttx/kernels/ilu/swa.py Outdated
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The rewritten paged-prefill SWA kernel uses log2-domain scaling inconsistently with the per-row int8 quant path and reads K/V scales with the wrong head index, which will produce numerically wrong outputs.

Summary

This PR rewrites the paged-prefill SWA-with-KV-dequant kernel from a per-token scalar loop into a tiled (BLOCK_M x BLOCK_N) online-softmax kernel with split global/local/causal-mask regions, plus adds a PIPELINE_STAGES fast-path loop in three paged-decode kernels. New helper jit functions _swa_acc_fwd_dequant_* and _swa_acc_fwd_int8_* are factored out for the inner accumulation step.

Must fix

  • [BLOCKER] K/V scale gather uses wrong head dimension -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1752-1769 -- The host code now slices k_qscale/v_qscale by gqa_ratio ("k_qscale[::gqa_ratio]" / "k_qscale[:num_kv_heads]") and passes that to the kernel, then loads with kv_head_id * stride_ks_h. But k_qscale is a per-(q_head, d) tensor in the old kernel (loaded with q_head_id * stride_ks_h); slicing it by gqa_ratio only makes sense if scales are actually per-kv-head. Either the input contract changed silently (call sites still pass per-q-head scales) or the indexing is wrong. Verify the layout and either keep q_head_id indexing or document/enforce per-kv-head input.
  • [BLOCKER] sm_scale is pre-multiplied by log2e but int8 path expects natural-log scale -- swa.py:1839, 357-360, 401-404 -- Host passes softmax_scale * LOG2E and the dequant kernels use tl.math.exp2 (correct). But _swa_acc_fwd_int8_mxn/_nomask compute qk = qk * q_quant_scale[:, None] * sm_scale and then tl.math.exp2(qk - m_ij), which is consistent only if sm_scale already includes log2e. That part is fine, however the final output for the int8 path uses acc * (1.0/127.0) * v_scale_vec / l_i_safe -- but the int8 helpers fold q_quant_scale into qk while p is requantized with scale 1/127, dropping the q_quant_scale factor on the V side; the old kernel also dropped q_quant_scale post-softmax (only on probabilities). Re-derive: with softmax over qk*q_qs*sm, q_quant_scale is absorbed into the exp argument so it should NOT appear in the output -- this is fine. But please double-check, because the dequant (non-int8) path applies k_scale/v_scale inside the helper while the int8 path applies v_scale only at the end and never applies k_scale (it's folded into q_scaled pre-quant). Easy to introduce a 1/127 vs k_scale mismatch.
  • [BLOCKER] q_block_start passed to tl.make_block_ptr offsets as Python int after min() -- swa.py:1395 -- q_block_end = min(q_block_start + BLOCK_M, q_seq_len) mixes a tl tensor (q_seq_len from tl.load) with Python min; in Triton this returns a tl value, but then q_block_len = q_block_end - q_block_start is used as a runtime length passed to make_block_ptr(shape=(q_seq_len, HEAD_DIM), offsets=(q_block_start.to(tl.int32),...)). The shape argument is q_seq_len (a runtime tl.int32) which is fine, but ensure q_block_start is also tl.int32 (it is kv_block_id*BLOCK_M -- a Python/tl scalar). Re-verify; subtle bugs here typically manifest as silent OOB loads.
  • [BLOCKER] q_valid zeroing of output may write garbage for partial tail tile -- swa.py:1722-1734 -- For the last q block where q_block_len < BLOCK_M, out_block = tl.where(q_valid[:, None], out_block, 0.0) zeros invalid rows but out_mask = q_valid[:, None] & (offs_d[None,:] < HEAD_DIM) then suppresses the store. That is correct, but note that for COMPUTE_INT8 the divide acc / l_i_safe over invalid rows divides by 1.0 of zeroed numerator -- OK. However m_i initialized to -inf for invalid rows leads to l_i==0, l_i_safe=1e-6 (not 1.0) for the int8 branch -- harmless because masked off, but inconsistent with the dequant branch which uses l_i>0 ? .. : 0. Worth unifying.

Suggestions

Suggestions (5)
  • [MAJOR] Massive code duplication across the three kv-block loops -- swa.py:1471-1720 -- The global/local/tail loops repeat identical make_block_ptr plumbing 4-5 times; extract a helper or at minimum the K/V block-ptr construction. Current form is a maintenance hazard and the next mask-region edit will silently diverge.
  • [MAJOR] non_global_window_start_block - 1 underflow -- swa.py:1554 -- full_local_end_block = non_global_window_start_block - 1 is set unconditionally before can_use_nomask_local is finalized; if non_global_window_start_block == 0 this is -1. It is later overwritten when can_use_nomask_local, but the dead initialization is confusing.
  • [MAJOR] Autotune key omits PAGE_SIZE and IS_CAUSAL -- swa.py:1287-1291 -- Different page sizes / causal flags pick different kv-block layouts; not keying on them risks reusing a stale-best config.
  • [MINOR] max_seqlen_q and compute_dtype silently ignored -- swa.py:1755, 1778-1780 -- Doc says "Unused; retained for API compatibility" but the only effective compute_dtype branch is torch.int8; bf16/fp16 selection is now driven solely by q.dtype. Remove the parameter or honor it; "ignored" parameters are a footgun for the modeling layer.
  • [MINOR] _paged_decode_pipeline_stages smem budget hard-codes 112 KB -- swa.py:3422 -- This is architecture-specific (ILU?) and there is no comment tying it to a target. At minimum hoist the constant and document; ideally query the device.

Nits

Nits (3)
  • [NIT] from typing import Tuple added but Tuple is unused in the diff -- swa.py:9.
  • [NIT] Comment "smem limit 128KB" lacks a space before # -- swa.py:3423.
  • [NIT] Inconsistent naming q_quant_scale_vec vs q_quant_scale between batched and decode kernels makes cross-reading harder -- swa.py:1413.

Notes

  • [CHECK] Confirm callers of swa_paged_prefill_with_kv_dequant_impl actually pass per-q-head k_qscale/v_qscale (shape [num_q_heads, D]); the new [::gqa_ratio] slicing assumes that and silently produces wrong scales otherwise.
  • [CHECK] The int8 P-requantization p_quant = ... .to(tl.int8).to(q_ty) then tl.dot(p_quant, v.to(q_ty)) -- verify the ILU compiler does not regress into the same "int8 provenance into tl.dot" bug the old scalar kernel was avoiding (see commit 1130f65 noted in the docstring).
  • [CHECK] LOG2E.value -- ensure LOG2E is imported in this file (not shown in diff); a NameError here would only surface at call time.

@sky-fun sky-fun force-pushed the ilu/ttx_opt_swa_quant_tmp branch from 4a74d88 to 17b67b5 Compare June 11, 2026 06:53
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The rewritten paged-prefill SWA kernel has a likely correctness bug in K/V scale indexing under GQA and a suspicious sm_scale doubling, and the new decode quant fast-path duplicates KV blocks.

Summary

This PR rewrites the int8 paged-prefill SWA kernel from a per-token scalar implementation to a tiled (BLOCK_M x BLOCK_N) implementation with shared masked/no-mask helpers, and adds a "fast-path" dense loop in three decode kernels for the case where the global+local windows cover the whole KV sequence. It also adds a pipeline-stages heuristic for decode kernels and changes how K/V per-channel scales are sliced under GQA.

Must fix

  • [BLOCKER] Decode quant fast-path double-counts KV -- mojo_opset/backends/ttx/kernels/ilu/swa.py:3223-3286 -- When (GLOBAL_WINDOW + LOCAL_WINDOW) >= kv_seq_len, fast_loop_end = num_total_blocks runs the full range, but gw_loop_end/local_loop_start are reset such that the global-window loop is skipped only if the condition is true; however the block immediately following (for kv_block_id in tl.range(0, gw_loop_end)) still runs whenever gw_loop_end > 0, and on the non-fast-path the fast-path itself also executes (it iterates 0..fast_loop_end=0, which is fine), but in the fast path, both the fast loop and the local-window loop's local_loop_start = num_total_blocks correctly skip; verify the local loop guard -- the issue is that on the non-quant decode kernels (_paged_decode_kernel, _paged_decode_kernel_tiny_global) the same pattern is fine, but in _paged_decode_quant_kernel the fast-path lacks the softmax_scale * LOG2E convention used elsewhere -- please confirm both loops are mutually exclusive in all branches and add a unit test for the full-cover case.
  • [BLOCKER] K/V scale slicing assumes contiguous GQA grouping -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1798-1807 -- k_qscale[::gqa_ratio] picks heads 0, gqa_ratio, 2*gqa_ratio, ... but k_qscale is shape (num_kv_heads, D) (indexed by kv_head_id inside the kernel, see K_qscale + kv_head_id * stride_ks_h); the old kernel indexed by q_head_id so this slicing was needed there, but the new kernel reads kv_head_id, so the host slicing is wrong and will pick the wrong KV heads for gqa_ratio > 1. Pass k_qscale/v_qscale directly without slicing.
  • [BLOCKER] sm_scale applied twice in int8 path -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1843 and :336, :387 -- Host passes softmax_scale * LOG2E as sm_scale, and _swa_acc_fwd_int8_mxn/_nomask_mxn compute qk * q_quant_scale * sm_scale then take exp2, which is correct; but _swa_acc_fwd_dequant_mxn/_nomask_mxn also multiply by qk_scale and use exp2 -- that part is consistent. However the q_scaled = q_block * k_scale_vec uses k_scale_vec per the dequant convention while the int8 path does NOT pre-multiply K by k_scale. Confirm K dequant scale is folded into the activation quant; otherwise int8 results will be missing the K per-channel scale entirely.
  • [BLOCKER] non_global_window_start_block computation may underflow when no global window -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1463 -- The "fall-through" branch iterates non_global_window_start_block..num_total_blocks; if _swa_split_blocks returns the same start as the global-loop end this is fine, but in the can_use_nomask_local=False path the global-window blocks (0..num_global_window_blocks) are run and then the loop restarts at non_global_window_start_block. Verify those two ranges do not overlap when GLOBAL_WINDOW_SIZE is None; the prior decode helper kept them disjoint, please add an assertion or test.

Suggestions

Suggestions (5)
  • [MAJOR] Massive code duplication across kv loops -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1410-1735 -- The four KV-loop bodies (global, pre-local-masked, local-no-mask, post-local-masked, fall-through) repeat ~30 lines of block_ptr setup each; factor into a @triton.jit helper similar to _swa_acc_fwd_*_mxn to reduce maintenance cost and the chance of divergence between branches.
  • [MAJOR] compute_dtype arg silently ignored except for int8 toggle -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1755-1778 -- Docstring says it is ignored, but it is still consulted via compute_dtype == torch.int8; rename to use_int8_compute: bool or assert valid values to avoid the silent-fallback footgun.
  • [MAJOR] Pipeline stages disabled comment contradicts code -- mojo_opset/backends/ttx/kernels/ilu/swa.py:3231-3234 -- Comment "Pipelining is disabled" but the loop uses default num_stages (compiler-chosen); if pipelining genuinely corrupts state, set num_stages=1 explicitly. Also PIPELINE_STAGES is computed and passed but unused in the quant kernel's fast loop, leaving dead config.
  • [MINOR] Magic constant 1e-6 for l_i floor -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1727 -- The int8 branch uses tl.maximum(l_i, 1e-6) while the bf16 branch correctly returns 0 when l_i == 0; the int8 branch will produce nonzero output for fully-masked rows. Match the bf16 branch's tl.where(l_i > 0, ...) form.
  • [MINOR] _paged_decode_pipeline_stages hardcodes 112KB SMEM -- mojo_opset/backends/ttx/kernels/ilu/swa.py:3424-3429 -- The 112KB / 128KB constants are device-specific; move to a backend capability lookup or document the assumed device.

Nits

Nits (3)
  • [NIT] from typing import Optional, Tuple -- mojo_opset/backends/ttx/kernels/ilu/swa.py:9 -- Tuple is imported but not used in the diff.
  • [NIT] Trailing comment style # smem limit 128KB -- mojo_opset/backends/ttx/kernels/ilu/swa.py:3425 -- missing space before #.
  • [NIT] Inconsistent loop iteration: range vs tl.range -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1376,1410,... -- the new kernel mixes Python range and tl.range; pick one for readability.

Notes

  • [CHECK] softmax_scale * LOG2E.value is now applied at host level for the prefill kernel, while the prior version used tl.math.exp with raw sm_scale; please verify the int8 helpers' exp2 math matches the bf16 reference numerically (especially the m_i initialization to -inf interacting with exp2(m_i - m_ij) when m_ij is also -inf).
  • [CHECK] _swa_split_blocks is referenced from the new prefill kernel but not shown in the diff; confirm its semantics match what the new loop structure expects (specifically, the num_global_window_blocks and non_global_window_start_block returned for the non-causal case).
  • [CHECK] The fast-path in _paged_decode_kernel_tiny_global sets nomask_start = num_total_blocks only inside the full_cover branch but earlier code already set local_loop_start = num_total_blocks; ensure the original for kv_block_id in tl.range(non_global_window_start_block, nomask_start) block is still reachable in non-full-cover cases.

…f16 swa paged prefill

[ilu/ttx] merge dequant and swa_infer

[ilu/ttx] swa paged prefill quant use exp2

[ilu/ttx] change block_n for occupancy, remove redundant type convert
@sky-fun sky-fun force-pushed the ilu/ttx_opt_swa_quant_tmp branch from 17b67b5 to c79c995 Compare June 11, 2026 07:51
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The new paged-prefill SWA kernel has a likely-incorrect KV-scale indexing change and a dead PIPELINE_STAGES parameter, plus a suspicious mask-handling pattern.

Summary

Rewrites the paged-prefill SWA-with-int8-KV kernel from a per-token scalar loop to a tiled BLOCK_M x BLOCK_N online-softmax kernel with autotuning, and adds PIPELINE_STAGES plumbing (currently unused) to the paged-decode kernels along with a global+local "full-cover" fast path.

Must fix

  • [BLOCKER] K/V scale indexing changed from per-Q-head to per-KV-head with manual stride -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1788-1797 -- The host now slices k_qscale[::gqa_ratio] (or [:num_kv_heads]) and passes its strides, but the kernel still indexes with kv_head_id * stride_ks_h. If k_qscale/v_qscale are actually shaped (num_q_heads, D) (as the previous kernel assumed by indexing with q_head_id), this either (a) silently produces wrong scales when GQA groups have different per-head scales or (b) breaks if scales were truly per-Q-head. Verify the canonical scale layout and either keep per-Q-head indexing or document/enforce per-KV-head shape.
  • [BLOCKER] PIPELINE_STAGES is plumbed through but never used -- mojo_opset/backends/ttx/kernels/ilu/swa.py:2517,2758,3172 -- The constexpr is added to all three decode kernels and computed/passed from the host, but no tl.range(..., num_stages=PIPELINE_STAGES) (or similar) consumes it. Either wire it into the hot loops or drop it; as-is it is dead code that suggests a half-applied change and adds recompiles per distinct value.
  • [BLOCKER] mask is False / mask is not True in @triton.jit helpers -- mojo_opset/backends/ttx/kernels/ilu/swa.py:288,294,353,359 -- These identity checks against Python singletons inside Triton-jitted code are fragile: mask is a tile here, not a Python bool, so mask is False is always false and mask is not None and mask is not True is always true. The early-return branch is dead and the conditional is effectively unconditional. If the intent was a constexpr toggle, take a MASK: tl.constexpr instead.

Suggestions

Suggestions (4)
  • [MAJOR] Massive code duplication across three loop variants -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1430-1736 -- The pre-global, full-local-nomask, post-local, and fallback loops repeat ~40 lines of make_block_ptr setup nearly verbatim; factor into a helper to reduce divergence risk between the masked/nomask paths.
  • [MAJOR] q_block_start.to(tl.int32) passed as offsets to make_block_ptr -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1391 -- tl.make_block_ptr offsets must be compile-time-constant or tl.int32 scalars; mixing min(...) results from q_block_end and Python int arithmetic with .to(tl.int32) here is fine, but q_block_len = q_block_end - q_block_start then feeds boundary_check-only logic. Confirm Triton accepts a runtime offset on this backend.
  • [MAJOR] Full-cover fast-path skips global-window blocks entirely -- mojo_opset/backends/ttx/kernels/ilu/swa.py:2563-2569,2806-2812,3225-3231 -- When GLOBAL_WINDOW + LOCAL_WINDOW >= kv_seq_len the kernel takes a single dense pass with only boundary masking and no causal/local/global mask. That is correct only if every kv index in [0, kv_seq_len) is admitted by the SWA mask for this query, which is true for decode (single q at the end) but verify it cannot fire in a context where some kv index is still excluded (e.g. with non-causal or unusual kv_seq_len vs query position relationships).
  • [MINOR] compute_dtype silently ignored -- mojo_opset/backends/ttx/kernels/ilu/swa.py:1779-1781 -- Docstring says it is ignored but COMPUTE_INT8 = compute_dtype == torch.int8 still uses it to pick the int8 path; the doc is misleading.

Nits

Nits (3)
  • [NIT] Stray inline comment style # smem limit 128KB lacks space -- mojo_opset/backends/ttx/kernels/ilu/swa.py:3428.
  • [NIT] Tuple imported but not obviously used in the diff -- mojo_opset/backends/ttx/kernels/ilu/swa.py:9.
  • [NIT] Long disabled-pipelining comments duplicated across three kernels -- consolidate into a module-level note -- swa.py:2562,2799,3232.

Notes

  • [CHECK] The new tiled kernel uses sm_scale = softmax_scale * LOG2E with tl.math.exp2, while the int8 path multiplies qk * q_quant_scale * sm_scale; confirm this matches the previous kernel's numerics (which used natural exp and raw sm_scale) end-to-end against a reference.
  • [CHECK] _swa_split_blocks is called with q_block_start + kv_computed_len as the q absolute start; confirm its contract expects the absolute KV-aligned position rather than a per-batch-relative one.
  • [CHECK] if page_size >= 128: block_n = 64 -- prior kernel used min(128, next_pow2(page_size)); confirm autotune configs cover BLOCK_N=64 and that no caller relies on BLOCK_N==page_size for correctness.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant