Skip to content

Genesis 6582#371

Merged
Neuromancer42 merged 16 commits into
masterfrom
genesis-6582
Jun 25, 2026
Merged

Genesis 6582#371
Neuromancer42 merged 16 commits into
masterfrom
genesis-6582

Conversation

@seainair

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for 4D query tensors (with sequence lengths up to 4) in the sliding window attention (SWA) paged decode kernel, updating the Triton kernel, PyTorch operator, and associated tests. The review feedback highlights critical issues in the Triton kernel implementation, specifically the incorrect reuse of query strides instead of output strides when calculating output pointers (which can cause memory corruption if the query is non-contiguous), and a potential division-by-zero risk when the sum of probabilities is zero.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/backends/ttx/kernels/mlu/swa.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/mlu/swa.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/mlu/swa.py
Comment thread mojo_opset/backends/ttx/kernels/mlu/swa.py Outdated
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Output strides incorrectly reuse query strides, which will corrupt output for any non-contiguous or differently-shaped output tensor.

Summary

Extends swa_paged_decode to support a multi-token query dimension (seq_len up to 4) in both the Triton kernel and the reference Python implementation, plus expanded test configs. The kernel signature gains stride_qs/stride_os and Q_SEQLEN, and the reference path branches on query.ndim == 4.

Must fix

  • [BLOCKER] Output strides set to query strides -- mojo_opset/backends/ttx/kernels/mlu/swa.py:960-963 -- The four stride_o* kernel arguments are passed stride_qb/qs/qh/qd instead of o.stride(...). Since o = torch.zeros_like(query) here this often happens to match, but it is wrong by construction and will silently break any caller whose output layout differs from the input. Pass o.stride(0..3) (handling the 3D vs 4D case symmetrically with the q-stride logic).
  • [BLOCKER] stride_qs defaults to 1 in 3D path -- mojo_opset/backends/ttx/kernels/mlu/swa.py:915-917 -- When q.ndim == 3, stride_qs = 1 is fabricated. With Q_SEQLEN=1 the offs_s[:, None, None] * stride_qs term is zero so it currently works, but if a future caller passes a 3D tensor while compiling with Q_SEQLEN>1, addresses will alias. Set stride_qs = 0 (or assert Q_SEQLEN == 1 in the 3D branch) to make the invariant explicit.
  • [BLOCKER] seq_lens <= 4 assertion is silent and undocumented -- mojo_opset/backends/ttx/kernels/mlu/swa.py:924 -- Bare assert seq_lens <= 4 with no message; assertions can be stripped under -O and the limit is not reflected in the public docstring/typing of MojoPagedDecodeSWA.forward. Replace with an explicit raise ValueError(...) and document the limit.

Suggestions

Suggestions (3)
  • [MAJOR] Reference impl uses query.ndim after rebinding o -- mojo_opset/core/operators/attention.py:743-751 -- The branching on query.ndim inside the loop is fine, but consider hoisting the 3D/4D split out of the per-batch loop to avoid repeated branching and make the two code paths easier to audit.
  • [MINOR] Trailing-whitespace-only line removed but blank line below _swa_split_blocks call also dropped -- mojo_opset/backends/ttx/kernels/mlu/swa.py:794 -- Unrelated formatting change mixed into a functional PR; keep the diff minimal.
  • [MINOR] Test ids duplicated across seq_len variants -- mojo_opset/tests/accuracy/operators/test_attention.py:1603-1617 -- Both the S=1 and S>1 entries reuse the same ID strings (e.g. two M_BF16_LONG), which produces duplicate pytest node ids and makes failures ambiguous. Suffix with _S{n}.

Nits

Nits (2)
  • [NIT] if (seq_len != 1): -- mojo_opset/tests/accuracy/operators/test_attention.py:43 -- drop parentheses; idiomatic Python.
  • [NIT] Trailing blank line removed at end of swa_paged_decode_impl is gratuitous churn -- mojo_opset/backends/ttx/kernels/mlu/swa.py:981.

Notes

  • [CHECK] Q_SEQLEN is now a tl.constexpr taken directly from runtime seq_lens; confirm the autotune/cache behavior is acceptable (each distinct seq_len triggers a recompile, up to 4 variants).
  • [CHECK] The kernel computes offs_s = tl.arange(0, Q_SEQLEN) with no masking on the seq dimension; verify that all callers actually fill all Q_SEQLEN query positions (no partial-batch padding) since OOB seq positions will be loaded and written without a mask.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The kernel passes Q strides for the output tensor, which will produce wrong results whenever Q and O have different layouts.

Summary

Adds multi-token (Q_SEQLEN > 1) support to the SWA paged decode kernel and reference implementation, generalizing query shape from [bsz, n_q_heads, head_dim] to optionally [bsz, seq_len, n_q_heads, head_dim]. Tests are extended to cover seq_len in {1..4}.

Must fix

  • [BLOCKER] Output strides replaced with query strides -- mojo_opset/backends/ttx/kernels/mlu/swa.py:960-963 -- The call site now passes stride_qb, stride_qs, stride_qh, stride_qd for the four stride_o* kernel params instead of o.stride(...). This produces wrong stores whenever o has a different layout (e.g. non-contiguous q via slicing/permute). Restore o.stride(0..3).
  • [BLOCKER] Reference forward writes wrong layout for ndim==3 -- mojo_opset/core/operators/attention.py:745-751 -- o_i is [n_q_heads, seq_len, head_dim] (per the bmm); for the 3D path you squeeze(1) and assign to o[i] whose shape is [n_q_heads, head_dim] -- fine. But in the 4D path you permute(1,0,2) to [seq_len, n_q_heads, head_dim] and assign to o[i] whose shape is [seq_len, n_q_heads, head_dim] -- also fine. However o = torch.zeros_like(query) means for ndim==4 o[i] is [seq_len, n_q_heads, head_dim]; double-check that o_i.permute is contiguous-compatible with the assignment (it is, but verify dtype cast ordering). Not a blocker by itself -- combined with the kernel issue above this path silently disagrees with the kernel. Please add a test that exercises non-contiguous q/o.
  • [BLOCKER] Hard-coded seq_lens <= 4 assertion without user-facing error -- mojo_opset/backends/ttx/kernels/mlu/swa.py:925 -- Bare assert with no message; will be stripped under -O and gives no guidance. Use an explicit if ... raise ValueError(...) and document the limit (also document where the 4 comes from -- compile-time tile size?).

Suggestions

Suggestions (4)
  • [MAJOR] Q_SEQLEN as constexpr forces recompile per seq_len -- mojo_opset/backends/ttx/kernels/mlu/swa.py:740 -- Each distinct seq_len triggers a new Triton compile; if callers hit varying lengths in decode, consider padding to a small fixed tile (e.g. next_pow2 up to 4) and masking on offs_s < Q_SEQLEN. Also currently there is no mask on offs_s, so any Q_SEQLEN not equal to the actual tile reads/writes OOB -- works only because you set it to the exact runtime value, which guarantees recompiles.
  • [MAJOR] stride_qs = 1 default for 3D inputs is arbitrary -- mojo_opset/backends/ttx/kernels/mlu/swa.py:917 -- When q.ndim == 3, Q_SEQLEN is 1 so offs_s = [0] and the stride is unused, but setting it to 1 is misleading; use 0 to make the intent explicit.
  • [MAJOR] Test ids duplicated -- mojo_opset/tests/accuracy/operators/test_attention.py:1603-1617 -- The same id= string (e.g. "M_BF16") is used for both S=1 and S=2 rows; pytest will append a numeric suffix but this hurts log readability. Encode S into the id.
  • [MINOR] Re-enabled previously commented-out cases without explanation -- mojo_opset/tests/accuracy/operators/test_attention.py:1603-1606 -- M_BF16_BIGPAGE and M_BF16_GROUP1 were commented out before; if they were disabled for a reason, please confirm in the PR description that the cause is fixed.

Nits

Nits (2)
  • [NIT] Trailing blank line removed and final blank line removed -- mojo_opset/backends/ttx/kernels/mlu/swa.py:790, 981 -- pure churn.
  • [NIT] Parenthesized condition if (seq_len != 1): -- mojo_opset/tests/accuracy/operators/test_attention.py:42 -- drop parens for PEP 8.

Notes

  • [CHECK] Confirm the kernel correctly handles Q_SEQLEN > 1 with the causal/window mask -- the mask currently does not appear to vary across the seq_len dimension, so all query positions in the tile may share the same mask, which is incorrect if positions correspond to different absolute offsets.
  • [CHECK] assert seq_lens <= 4 -- verify whether 4 is a kernel tile constraint or just what tests cover; if the latter, the limit should be removed or raised.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Output strides are passed as Q strides in the kernel call, which will corrupt writes whenever Q and O have different layouts.

Summary

Adds support for a small Q sequence length dimension (up to 4) in the SWA paged decode path, both in the Triton kernel and the reference operator. Tests are extended to cover both 3D and 4D query shapes.

Must fix

  • [BLOCKER] Output strides replaced with Q strides -- mojo_opset/backends/ttx/kernels/mlu/swa.py:960-963 -- The call now passes stride_qb, stride_qs, stride_qh, stride_qd for the o strides (stride_ob/os/oh/od). Should be o.stride(0..3) (handling the 3D vs 4D case the same way as Q), otherwise stores go to wrong addresses whenever o has a different layout than q (e.g. torch.zeros_like of a non-contiguous Q, or differing shapes).
  • [BLOCKER] stride_qs=1 for 3D path is wrong when used with Q_SEQLEN=1 -- mojo_opset/backends/ttx/kernels/mlu/swa.py:914-917 -- For the 3D fallback stride_qs is hardcoded to 1, but offs_s[:, None, None] * stride_qs with Q_SEQLEN=1 only multiplies by 0 so it happens to work; however stride_qs=1 collides with stride_qd (also 1 for contiguous tensors) and is misleading. Set it to 0 (or any value, since offs_s is [0]) and document, or better, reshape Q to 4D and use real strides.

Suggestions

Suggestions (4)
  • [MAJOR] Silent assertion limit -- mojo_opset/backends/ttx/kernels/mlu/swa.py:925 -- assert seq_lens <= 4 is an undocumented hard limit; raise a clear error (or expose a constant) and document it in MojoPagedDecodeSWA.forward.
  • [MAJOR] Reference path duplicates branch -- mojo_opset/core/operators/attention.py:748-752 -- The two branches differ only in squeeze(1); collapse to one assignment after a single conditional permute/squeeze to reduce drift.
  • [MINOR] Q_SEQLEN as tl.constexpr forces recompilation per seq_len -- mojo_opset/backends/ttx/kernels/mlu/swa.py:740 -- Acceptable given the <=4 cap, but worth noting; consider padding to a fixed power of two with a mask to limit kernel variants.
  • [MINOR] Test ids duplicated -- mojo_opset/tests/accuracy/operators/test_attention.py:1610-1618 -- Several configs reuse the same ID string (e.g. M_BF16, M_BF16_PADDIM) across different S, which makes pytest ids ambiguous; suffix with _S{S}.

Nits

Nits (3)
  • [NIT] Trailing-whitespace removal mixed with logic change -- mojo_opset/backends/ttx/kernels/mlu/swa.py:790,981 -- fine, just noise.
  • [NIT] Typo in comment -- mojo_opset/tests/accuracy/operators/test_attention.py:1619 -- "esle" -> "else".
  • [NIT] Parenthesized condition -- mojo_opset/tests/accuracy/operators/test_attention.py:42 -- if (seq_len != -1): -- drop parens, prefer is not for sentinel or use Optional[int] = None.

Notes

  • [CHECK] Confirm tl.reshape on MLU/Triton variant in use preserves the intended [S, H, D] -> [S*H, D] row-major ordering; if not, the per-head softmax stats m_i/l_i will be mismatched against acc rows.
  • [CHECK] Verify torch.zeros_like(query) produces the same stride layout the kernel assumes; if query is a non-contiguous view (common for transposed KV-cache feeds), the output stride bug above will be observable.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The kernel does not handle the q-seqlen dimension within the attention computation (no per-token masking against kv positions), which will produce incorrect results for seq_len > 1.

Summary

Adds support for a query seq_len dimension (up to 4) to the SWA paged decode kernel and reference implementation, with an updated test matrix. The kernel reshapes Q to [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D] and runs the existing single-token attention loop over all query tokens at once.

Must fix

  • [BLOCKER] Per-token causal/window masking missing in kernel -- mojo_opset/backends/ttx/kernels/mlu/swa.py:771-895 -- After flattening Q to [Q_SEQLEN*HEADS, D], the kv loop and masking logic still treat all query tokens as the single position kv_seq_len-1. For Q_SEQLEN > 1, each query token corresponds to a different absolute position (kv_seq_len - Q_SEQLEN + s) and needs its own causal/local-window mask; the current code will attend to future kv tokens and apply the wrong window, producing results that diverge from the reference (which masks per query row).
  • [BLOCKER] assert seq_lens <= 4 is silently ignored for 3D input -- mojo_opset/backends/ttx/kernels/mlu/swa.py:932 -- For 3D inputs seq_lens=1 so this is fine, but there is no check that Q_SEQLEN is a power of two / matches kernel constraints, and no informative error. More importantly the limit of 4 is a hard kernel constraint that is undocumented at the operator API layer (MojoPagedDecodeSWA.forward); add the same assertion and a comment there so callers fail early with context.
  • [BLOCKER] Output stride for non-contiguous outputs -- mojo_opset/backends/ttx/kernels/mlu/swa.py:911-925 -- stride_ob/os/oh/od are computed from the input shape assuming a contiguous output, but o = torch.empty_like(...) / the caller's output tensor strides are not consulted (compare to the prior code that used o.stride(...)). If o is ever non-contiguous this will corrupt memory; use o.stride() instead of recomputing.

Suggestions

Suggestions (3)
  • [MAJOR] Reference impl reshape may be wrong for 4D path -- mojo_opset/core/operators/attention.py:701-704 -- q_i = query[i].permute(1,0,2) gives [n_q_heads, seq_len, head_dim], which is then bmm'd with k_i_T of shape [n_q_heads, head_dim, kv_seq_len]. That works, but the per-row causal mask in _generate_window_mask must be generated for seq_len query rows aligned to the last seq_len kv positions; verify _generate_window_mask is being called with the correct q_len/offset for the 4D case.
  • [MAJOR] Divide-by-zero guard changes semantics -- mojo_opset/backends/ttx/kernels/mlu/swa.py:894 -- The new tl.where(l_i>0, acc/where(l_i>0, l_i, 1), 0) runs unconditionally, replacing the prior if kv_seq_len > 0 guarded division. For rows fully masked out (e.g. local window with no valid kv) l_i can be 0 even when kv_seq_len > 0; good. But the outer if kv_seq_len > 0 block is now dead-ish - acc stays zero anyway when l_i==0. Consider removing the outer if for clarity, or keeping it and dropping the inner where.
  • [MINOR] stride_qs=1 placeholder for 3D -- mojo_opset/backends/ttx/kernels/mlu/swa.py:917 -- For 3D input Q_SEQLEN=1 so offs_s=[0] and the stride is unused, but setting it to 1 is misleading; use 0 to make intent explicit.

Nits

Nits (3)
  • [NIT] Stray trailing-whitespace removal and removed blank line at mojo_opset/backends/ttx/kernels/mlu/swa.py:794 and :989 are unrelated churn.
  • [NIT] Comment typo esle -> else -- mojo_opset/tests/accuracy/operators/test_attention.py:1622.
  • [NIT] Duplicate test ids (e.g. two M_BF16, two M_BF16_PADDIM) -- mojo_opset/tests/accuracy/operators/test_attention.py:1603-1618 -- pytest will disambiguate but the IDs become non-unique and harder to grep.

Notes

  • [CHECK] mojo_opset/backends/ttx/kernels/mlu/swa.py:777 -- tl.reshape(q, [Q_SEQLEN*BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D]) assumes the GQA-head dimension is the fastest non-D axis after seq; confirm this matches the downstream qk/qv matmul layout for Q_SEQLEN > 1 (i.e. that mixing seq and head into one M dim is what the rest of the loop expects).
  • [CHECK] mojo_opset/backends/ttx/kernels/mlu/swa.py:935 -- the assert seq_lens <= 4 upper bound: is this a kernel tiling limit (pipeline_strategies / num_warps=1) or arbitrary? Document in a comment.

@Neuromancer42

Copy link
Copy Markdown
Collaborator

Verdict: Request changes -- The kernel does not handle the q-seqlen dimension within the attention computation (no per-token masking against kv positions), which will produce incorrect results for seq_len > 1.

注意这个问题,decode N个token应当产生[N, kv_len]的causal mask,而非[1, kv_len]

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- MLU kernel masking and reshape logic for the new MTP path looks incorrect; please verify.

Summary

Adds support for 4D (MTP) query input [bsz, seq_len, n_q_heads, head_dim] to the SWA paged decode path, implementing it on MLU and explicitly raising NotImplementedError on ILU/NPU. Reference op in core is generalized and tests parameterize over seq_len.

Must fix

  • [BLOCKER] Masking missing for padded Q seq positions -- mojo_opset/backends/ttx/kernels/mlu/swa.py:781 -- Q_SEQLEN is the constexpr tile size (>= actual seq_lens, asserted <= 4) but there is no offs_s < seq_lens mask on the Q load or O store. If a caller passes any seq_len smaller than the constexpr tile (or non-power-of-2 padded), garbage Q is loaded and OOB stores occur. Add a seq mask, or pass actual seq_len and mask both load/store.
  • [BLOCKER] Causal/window mask does not account for Q sequence position -- mojo_opset/backends/ttx/kernels/mlu/swa.py:771-895 -- The reshape collapses [Q_SEQLEN, BLOCK_SIZE_Q_HEADS] into the head dimension and the existing scoring/masking logic (unchanged in diff) treats all rows as the same query position kv_seq_len-1. For MTP each of the seq_len query tokens has a different causal position, so the SWA mask is wrong for seq_len > 1. Compare against the reference in core/operators/attention.py which uses _generate_window_mask(q_i.shape[1], ...).
  • [BLOCKER] seq_lens passed as runtime tile size -- mojo_opset/backends/ttx/kernels/mlu/swa.py:968 and :737 -- Q_SEQLEN is tl.constexpr but the host passes the dynamic seq_lens from q.shape. tl.arange(0, Q_SEQLEN) requires a power-of-two; seq_lens=3 will fail to compile, and varying values cause recompiles. Round up to a constexpr tile (e.g. next_pow2) and mask.
  • [BLOCKER] ILU/NPU silently broken for MTP callers -- mojo_opset/backends/ttx/kernels/ilu/swa.py:2895, mojo_opset/backends/ttx/kernels/npu/swa.py:1277 -- Raising NotImplementedError is fine, but MojoPagedDecodeSWA.forward in core accepts 4D and dispatches without checking backend capability, so users on those backends will hit a hard error at kernel time. Either gate at the operator level or document/wire a fallback to the reference path.

Suggestions

Suggestions (3)
  • [MAJOR] Double tl.where on l_i -- mojo_opset/backends/ttx/kernels/mlu/swa.py:894 -- acc / tl.where(l_i>0, l_i, 1.0) is already safe; the outer tl.where(l_i>0, ..., 0.0) is redundant and changes behavior vs original (which divided unconditionally and was guarded by if kv_seq_len > 0). Keep it simple.
  • [MAJOR] assert seq_lens <= 4 is silent magic -- mojo_opset/backends/ttx/kernels/mlu/swa.py:933 -- No comment on why 4. Make this a named constant or raise with a clear message tied to MTP token budget.
  • [MINOR] Output stride computation ignores actual o layout -- mojo_opset/backends/ttx/kernels/mlu/swa.py:912-924 -- Strides are computed from shape assuming contiguous output, but o = torch.zeros_like(query) later (in the existing surrounding code path) may inherit query's non-contiguous strides. Use o.stride() after allocation, as the original did.

Nits

Nits (3)
  • [NIT] Typo in test comment "esle" -- mojo_opset/tests/accuracy/operators/test_attention.py:1626.
  • [NIT] Duplicate test ids ("M_BF16", "M_BF16_PADDIM", etc.) across 3D/4D rows -- mojo_opset/tests/accuracy/operators/test_attention.py:1603-1618 -- pytest will mangle them; add a _MTP{S} suffix.
  • [NIT] Trailing blank line removed and parens around if (seq_len != -1): -- mojo_opset/tests/accuracy/operators/test_attention.py:43.

Notes

  • [CHECK] MojoPagedDecodeSWA.forward reference: when query is 4D, o = torch.zeros_like(query) is [bsz, seq_len, n_q_heads, head_dim] and o[i] = o_i where o_i is [seq_len, n_q_heads, head_dim] -- looks right, but please verify dtype/contiguity matches the kernel path used in tests.
  • [CHECK] stride_qs=1 in the 3D fallback is harmless only if Q_SEQLEN==1 so the offset is always 0; confirm the kernel is never instantiated with Q_SEQLEN>1 on 3D inputs.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- N-step SWA decode kernel has incorrect output strides for the 3D query path and a redundant double-where in the safe-divide.

Summary

Adds an n-step variant of paged SWA decode (query shape [bsz, seq_len, n_q_heads, head_dim]), wires it through the experimental operator registry, the TTX/MLU backend kernel, and adds reference + accuracy tests. The Triton kernel is generalized over a new Q_SEQLEN constexpr and reshapes Q/acc to fold seq_len into the head dimension.

Must fix

  • [BLOCKER] Output strides ignore the user's tensor layout (3D path) -- mojo_opset/backends/ttx/kernels/mlu/swa.py:911-925 -- For both 4D and 3D inputs, stride_o* are computed from shapes assuming a contiguous output, but o is allocated by the caller (torch.zeros_like(query) in the operator) and may not be contiguous (e.g. permuted query). Use o.stride(...) instead of recomputing from shape, otherwise stores will land at the wrong addresses for non-contiguous outputs.
  • [BLOCKER] Redundant/confusing safe-divide -- mojo_opset/backends/ttx/kernels/mlu/swa.py:894 -- tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) evaluates acc/l_i semantically but the inner where is the only one that prevents div-by-zero; the outer branch then masks valid results to 0 when l_i<=0, which differs from prior behavior (previously gated by kv_seq_len > 0). For multi-step queries some rows may be fully masked (causal/local window) producing l_i==0 legitimately; silently zeroing is fine but the double-where is dead code -- simplify to a single tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) outside the if kv_seq_len > 0 guard, or document why both layers are needed.
  • [BLOCKER] Q_SEQLEN <= 4 assertion is silent for callers -- mojo_opset/backends/ttx/kernels/mlu/swa.py:933 -- Bare assert seq_lens <= 4 with no message; if Python is run with -O this disappears and the kernel will allocate huge tiles or miscompile. Use an explicit if ...: raise ValueError(...) and state the actual constraint (where does 4 come from?).

Suggestions

Suggestions (4)
  • [MAJOR] seq_lens passed as runtime int but declared tl.constexpr -- mojo_opset/backends/ttx/kernels/mlu/swa.py:734,968 -- Q_SEQLEN: tl.constexpr will specialize/recompile per distinct seq_lens value. Confirm callers only use a small fixed set (the <=4 assertion suggests so) and consider documenting this to avoid recompile storms.
  • [MAJOR] Reference impl errors on padded rows when block_table[i,0] < 0 -- mojo_opset/experimental/operators/attention.py:1232-1233 -- The kernel path tolerates kv_seq_len==0 (skips work and writes zeros), but the reference raises ValueError. The kv_seq_len <= 0 branch above already continues, so this check is only hit with kv_seq_len>0 -- fine, but verify the test data generator never produces kv_seq_len>0 with block_table[i,0]<0 for the new n-step tests.
  • [MAJOR] assert_paged_decode_contract may not validate 4D query shape -- mojo_opset/backends/ttx/operators/attention.py:451 and mojo_opset/experimental/operators/attention.py:1213 -- Make sure the contract helper accepts the new 4D query layout, or add explicit shape validation here.
  • [MINOR] MojoPagedDecodeNstepSWA reuses the same kernel as decode-1 -- mojo_opset/backends/ttx/operators/attention.py:434-460 -- Consider whether a new public operator is justified versus extending MojoPagedDecodeSWA to accept either rank; duplicating the operator increases registry/maintenance surface.

Nits

Nits (4)
  • [NIT] Stray trailing-blank-line removal mixed with logic changes -- mojo_opset/backends/ttx/kernels/mlu/swa.py:790,989 -- unrelated whitespace churn.
  • [NIT] if (seq_len != -1): -- mojo_opset/tests/accuracy/operators/test_attention.py:43 -- drop parens; using -1 as a sentinel is brittle, prefer Optional[int] = None.
  • [NIT] Removed blank line between test_configs_swa_decode = [...] and @pytest.mark.parametrize -- mojo_opset/tests/accuracy/operators/test_attention.py:1612 -- minor formatting regression.
  • [NIT] Duplicate test ids in test_configs_swa_nstep_decode ("M_BF16" and "M_BF16_PADDIM" appear twice) -- mojo_opset/tests/accuracy/operators/test_attention.py:1682-1685 -- pytest will append 0/1 but better to encode S in the id.

Notes

  • [CHECK] The new tl.reshape(q, [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D]) requires Triton to support reshape on a 3D loaded tile on the MLU backend; please confirm this lowers correctly for all Q_SEQLEN in {1,2,3,4} and the chosen BLOCK_SIZE_Q_HEADS.
  • [CHECK] With Q_SEQLEN>1 the existing causal/window masking inside the inner kv-block loops (not shown in diff) must mask per-query-step, not per-batch -- verify the unchanged mask code uses offs_s correctly, otherwise the kernel will compute wrong attention for n-step decode despite tests passing on small shapes.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The kernel's stride/shape handling for the n-step path has correctness concerns and the assertion seq_lens <= 4 with no error message is too restrictive/silent.

Summary

Adds an n-step paged-decode SWA path: a new MojoPagedDecodeNstepSWA operator (torch reference + ttx registration) and extends the existing Triton _swa_paged_decode_kernel to accept a [bsz, seq_len, n_q_heads, head_dim] query with causal/local/global masking adjusted for Q_SEQLEN > 1.

Must fix

  • [BLOCKER] Output strides ignore actual tensor layout -- mojo_opset/backends/ttx/kernels/mlu/swa.py:919-934 -- For both 4D and 3D paths, stride_ob/os/oh/od are computed from shape (assuming contiguous) instead of from o.stride(). o = torch.zeros_like(query) is contiguous today, but this is fragile and inconsistent with how Q strides are taken from q.stride(). Use o.stride().
  • [BLOCKER] Q_SEQLEN is passed as a runtime int but declared tl.constexpr -- mojo_opset/backends/ttx/kernels/mlu/swa.py:741, 965 -- seq_lens comes from q.shape[1] (a Python int, fine for constexpr) but in the 3D branch it is set to literal 1; ensure call sites always pass a Python int, and document that this triggers recompilation per seq_len. Also the assert seq_lens <= 4 at line ~941 has no message and silently caps a public API -- either raise a descriptive ValueError or remove the magic number.
  • [BLOCKER] Causal mask indexing assumes a specific token ordering -- swa.py:828-833, 875-878 -- seq_offsets = arange(BLOCK_SIZE_Q_HEADS * Q_SEQLEN) // BLOCK_SIZE_Q_HEADS makes row r correspond to query step r // BLOCK_SIZE_Q_HEADS, which only matches the reshape q -> [Q_SEQLEN*BLOCK_SIZE_Q_HEADS, D] (step-major, head-minor). Confirm and add a comment; if BLOCK_SIZE_Q_HEADS is ever > GQA_GROUP_SIZE the per-row mapping is still correct but the masked-out head rows will participate in softmax denominators -- verify l_i/m_i handling.
  • [BLOCKER] Double tl.where on l_i is redundant and hides the bug it tries to fix -- swa.py:901 -- tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) -- the inner where already prevents div-by-zero; the outer where additionally zeros valid rows where l_i is subnormal-but-positive only if l_i==0. Simplify to one tl.where, and note this differs from the previous behavior where fully-masked rows produced nan. If rows can be fully masked (e.g. Q_SEQLEN > kv_seq_len for a sample), the torch reference returns nan/-inf softmax, so outputs will diverge.

Suggestions

Suggestions (4)
  • [MAJOR] Reference op uses .item() per batch, breaks CUDA graphs / async -- mojo_opset/experimental/operators/attention.py:1226, 1229 -- total_seq_lens[i].item() and block_table[i,0].item() force host syncs in the torch reference; fine for tests but worth a comment so it is not used as a perf baseline.
  • [MAJOR] Silent skip of kv_seq_len <= 0 rows in reference but kernel writes zeros -- attention.py:1226-1228 vs swa.py:903 -- reference leaves o[i] as the pre-allocated zeros_like; kernel branches on kv_seq_len > 0 and otherwise leaves acc=0/stores zeros. OK, but both relying on torch.zeros_like initialization is implicit -- document it.
  • [MAJOR] assert seq_lens <= 4 is an undocumented hard limit -- swa.py:941 -- expose as a named constant or kwarg, and validate in the public MojoPagedDecodeNstepSWA rather than only in the ttx impl.
  • [MINOR] Test parametrization marker -- mojo_opset/tests/accuracy/operators/test_attention.py:1607 -- seq_len=-1 sentinel for "decode" is awkward; consider seq_len=None and branch explicitly.

Nits

Nits (3)
  • [NIT] Typo casul_mask -> causal_mask -- swa.py:833, 836, 877, 879, 881.
  • [NIT] Trailing whitespace after casul_mask line -- swa.py:881.
  • [NIT] Removed trailing newline at EOF of swa.py -- swa.py:998.

Notes

  • [CHECK] The new branch enables previously-commented test configs (M_BF16_BIGPAGE, M_BF16_GROUP1) for the existing test_paged_decode_swa -- confirm those were intentionally re-enabled and pass on MLU; they were commented out before for a reason.
  • [CHECK] stride_qs = 1 in the 3D fallback (swa.py:931) is only safe because offs_s = arange(0, 1) so the offset is 0; verify Triton does not emit a real stride-1 load that aliases head dim.

zhouronghai and others added 14 commits June 22, 2026 11:42
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- N-step SWA decode kernel has correctness bugs in masking and a misplaced division-by-zero guard, plus a hard-coded seq_len limit and typos.

Summary

Adds a new MojoPagedDecodeNstepSWA operator and a multi-step (n-step, seq_len > 1) extension to the Triton SWA paged-decode kernel for MLU. The kernel is generalized to handle a 4D query [bsz, seq_len, n_q_heads, head_dim], with new strides, causal/local-window masks indexed by per-row sequence offset, and a torch reference implementation plus tests.

Must fix

  • [BLOCKER] Causal mask uses wrong inequality direction -- mojo_opset/backends/ttx/kernels/mlu/swa.py:830,876 -- casul_mask = base - seq_offsets <= kv_seq_len - Q_SEQLEN always allows the last Q_SEQLEN positions for every query row. The correct causal condition for query row s (0-indexed within the step) is kv_pos <= kv_seq_len - Q_SEQLEN + s, i.e. base[None,:] <= kv_seq_len - Q_SEQLEN + seq_offsets[:,None] (no subtraction of seq_offsets on the LHS). As written, row 0 attends to tokens it must not see.
  • [BLOCKER] seq_offsets indexing assumes a head-minor layout -- mojo_opset/backends/ttx/kernels/mlu/swa.py:825,873 -- seq_offsets = arange(BLOCK_SIZE_Q_HEADS * Q_SEQLEN) // BLOCK_SIZE_Q_HEADS maps the flattened axis as [h0s0, h1s0, ..., hNs0, h0s1, ...], but q is reshaped from [Q_SEQLEN, BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D] whose flatten order is seq-major ([s0h0, s0h1, ...]). The seq index should be arange(...) // BLOCK_SIZE_Q_HEADS only if the inner axis is heads, which matches reshape order, but the m_i/l_i/acc semantics in _decode_acc_fwd_MxN assume per-row scaling -- please verify the row ordering matches between q, acc, and the mask; current code looks inconsistent with tl.reshape(q, [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D]).
  • [BLOCKER] Division-by-zero guard moved out of the if kv_seq_len > 0 branch is fine, but the new form still divides under both branches of tl.where -- mojo_opset/backends/ttx/kernels/mlu/swa.py:901 -- the outer tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) is OK, but it executes only when kv_seq_len > 0; rows that received no unmasked keys (all -inf) still have l_i==0 and are now silently zeroed, which previously was undefined. Confirm this is intentional and that m_i==-inf rows do not produce NaNs in acc upstream (the existing _decode_acc_fwd_MxN may already produce NaN in acc when m_i is -inf).
  • [BLOCKER] Hard-coded assert seq_lens <= 4 -- mojo_opset/backends/ttx/kernels/mlu/swa.py:942 -- silently caps n-step decode at 4 with no error message and no documented rationale; either lift the limit, parameterize it, or raise a clear error explaining the constraint.

Suggestions

Suggestions (4)
  • [MAJOR] Local-window mask formula -- mojo_opset/backends/ttx/kernels/mlu/swa.py:828,875 -- base + LOCAL_WINDOW >= kv_seq_len - (Q_SEQLEN - seq_offsets) is equivalent to base + LOCAL_WINDOW >= kv_seq_len - Q_SEQLEN + seq_offsets; given the causal mask bug above, double-check this against the torch reference _generate_window_mask to ensure both endpoints match for every query row.
  • [MAJOR] Reference implementation skips zero-length rows but kernel writes zeros -- mojo_opset/experimental/operators/attention.py:1235-1238 -- the torch ref continues when kv_seq_len <= 0, leaving o[i] as the torch.zeros_like init; the Triton kernel also relies on o being pre-zeroed. The wrapper at swa.py:983 allocates o -- confirm it is zero-initialized (the original used torch.empty-like; n-step path should use torch.zeros to match).
  • [MINOR] casul_mask typo -- mojo_opset/backends/ttx/kernels/mlu/swa.py:830,876 -- spell as causal_mask.
  • [MINOR] seq_lens is a tensor-or-int polymorphism -- mojo_opset/backends/ttx/kernels/mlu/swa.py:919-933 -- seq_lens is reused as both the python int Q_SEQLEN constexpr and the wrapper's local; rename one to avoid confusion with total_seq_lens.

Nits

Nits (3)
  • [NIT] Trailing whitespace on new mask line -- mojo_opset/backends/ttx/kernels/mlu/swa.py:881.
  • [NIT] Removed trailing newline before return o -- mojo_opset/backends/ttx/kernels/mlu/swa.py:998.
  • [NIT] Test parametrize tuple width changed; the if (seq_len != -1) parens are un-pythonic -- mojo_opset/tests/accuracy/operators/test_attention.py:43.

Notes

  • [CHECK] _decode_acc_fwd_MxN is called unchanged but now receives M = Q_SEQLEN * BLOCK_SIZE_Q_HEADS and a 2D mask [M, BLOCK_SIZE_N]; verify it broadcasts the mask along M correctly and that softmax stats are tracked per row of the flattened axis.
  • [CHECK] stride_ob/os/oh/od are computed from shape rather than o.stride() in the 4D path; this assumes o is contiguous. Confirm torch.zeros_like(query) preserves contiguity when query is non-contiguous (e.g. a permuted view from the caller).

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- N-step SWA decode kernel has incorrect mask arithmetic in the non-global window branch and a few correctness issues that will produce wrong outputs.

Summary

Adds an n-step variant of paged decode SWA that accepts a 4D query [bsz, seq_len, n_q_heads, head_dim], threading a new Q_SEQLEN constexpr and seq strides through the Triton kernel. Also adds a torch reference operator MojoPagedDecodeNstepSWA and tests.

Must fix

  • [BLOCKER] Wrong causal/sw mask indexing in non-global branch -- mojo_opset/backends/ttx/kernels/mlu/swa.py:875-881 -- seq_offsets = arange(BLOCK_SIZE_Q_HEADS * Q_SEQLEN) // BLOCK_SIZE_Q_HEADS produces per-head indices, but acc/m_i/l_i are laid out as [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, ...] (i.e. seq is the outer dim, see the reshape at line ~786). The seq index for row r is r // BLOCK_SIZE_Q_HEADS only if seq is outer; verify against the reshape order, and also check the global-window branch which uses the same expression. Additionally, casul_mask = base - seq_offsets <= kv_seq_len - Q_SEQLEN is wrong: causal should be base <= kv_seq_len - Q_SEQLEN + seq_offsets (matches the global branch). These two diverging formulas cannot both be correct.
  • [BLOCKER] Inconsistent sw_mask formula between branches -- mojo_opset/backends/ttx/kernels/mlu/swa.py:872 -- Global branch uses base + LOCAL_WINDOW >= kv_seq_len - Q_SEQLEN + seq_offsets; non-global branch uses base + LOCAL_WINDOW >= kv_seq_len - (Q_SEQLEN - seq_offsets) which equals base + LOCAL_WINDOW >= kv_seq_len - Q_SEQLEN + seq_offsets only by coincidence of signs -- please make them identical and add a comment, since the two-form discrepancy is a bug magnet.
  • [BLOCKER] assert seq_lens <= 4 is silent and arbitrary -- mojo_opset/backends/ttx/kernels/mlu/swa.py:942 -- Hard cap with no error message and no documentation; either raise a clear ValueError explaining the constraint (and why 4) or remove. Also, when q.ndim == 3, seq_lens = 1 is passed as Q_SEQLEN constexpr which makes tl.arange(0, 1) legal but Q_SEQLEN must be a power of 2 for many Triton ops -- confirm 3D path still works.
  • [BLOCKER] Stride computation assumes contiguous output -- mojo_opset/backends/ttx/kernels/mlu/swa.py:919-932 -- stride_ob/os/oh/od are computed from shapes rather than from o.stride(). o = torch.zeros_like(query) is contiguous today, but this is fragile; just use o.stride() like before. Same comment applies to the 3D branch.
  • [BLOCKER] repeat_interleave semantics for GQA in torch reference -- mojo_opset/experimental/operators/attention.py:1244-1252 -- Verify the interleave/non-interleave mapping matches the kernel's GQA_INTERLEAVE convention; the existing MojoPagedDecodeSWAWithKVDequant should be the source of truth, please cross-check rather than re-deriving.

Suggestions

Suggestions (5)
  • [MAJOR] Typo casul_mask -> causal_mask -- mojo_opset/backends/ttx/kernels/mlu/swa.py:830,876 -- Will outlive this PR; fix now.
  • [MAJOR] Double tl.where for division-by-zero guard is redundant -- mojo_opset/backends/ttx/kernels/mlu/swa.py:901 -- tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) -- the inner where already prevents div-by-zero; the outer where is fine but the nested form is confusing. Simplify to safe_l = tl.where(l_i > 0, l_i, 1.0); acc = acc / safe_l (zero-init acc already gives 0 when l_i==0).
  • [MAJOR] Per-batch Python loop in torch reference -- mojo_opset/experimental/operators/attention.py:1228-1262 -- Fine for a reference, but total_seq_lens[i].item() and block_table[i, 0].item() force device sync per batch; acceptable for tests but please add a comment that this is a non-perf reference.
  • [MINOR] Assertion message references variable poorly -- mojo_opset/backends/ttx/kernels/mlu/swa.py:925-926 -- seqlens[torch.where(seqlens < seq_lens)] triggers another sync inside an assert even on the happy path due to f-string eager eval. Use a lazy form or split the check.
  • [MINOR] Public op exported but only mlu supported -- mojo_opset/backends/ttx/operators/attention.py:436 -- Confirm there is a torch fallback registration path so non-mlu users get a clean NotImplemented rather than an obscure failure.

Nits

Nits (3)
  • [NIT] Trailing whitespace and stray blank-line churn -- mojo_opset/backends/ttx/kernels/mlu/swa.py:820,872,884 -- mixed in with logic changes makes review harder.
  • [NIT] Removed trailing newline before return o -- mojo_opset/backends/ttx/kernels/mlu/swa.py:998 -- unrelated style change.
  • [NIT] if (seq_len != -1): in test helper -- mojo_opset/tests/accuracy/operators/test_attention.py:43 -- drop parens, PEP8.

Notes

  • [CHECK] The reshape q = tl.reshape(q, [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D]) flattens seq as the outer dim, so all seq_offsets expressions must use r // BLOCK_SIZE_Q_HEADS. Please verify by printing/inspecting on a small case (e.g. Q_SEQLEN=2, BLOCK_SIZE_Q_HEADS=4) before merging.
  • [CHECK] The new tests reuse ids ("M_BF16", "M_BF16_PADDIM") for two different parametrize entries -- pytest may complain about duplicate ids.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- N-step SWA decode kernel has a mask-layout bug and a typo that will produce wrong results for GQA.

Summary

Adds an n-step paged-decode SWA path: a new MojoPagedDecodeNstepSWA operator + reference, a TTX MLU Triton kernel extended to accept [bsz, seq_len, n_q_heads, head_dim] queries, and tests. The motivation appears to be supporting multi-token decode (speculative/MTP-style) with sliding-window attention.

Must fix

  • [BLOCKER] seq_offsets indexing assumes wrong head/seq layout -- mojo_opset/backends/ttx/kernels/mlu/swa.py:828, 877 -- q is reshaped as [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D] from a tensor whose first two dims are (offs_s, offs_head_block), so the flattened row index is s * BLOCK_SIZE_Q_HEADS + h. Therefore seq_offsets = idx // BLOCK_SIZE_Q_HEADS is correct only when BLOCK_SIZE_Q_HEADS matches the actual GQA head block; for multi-head GQA (e.g. M_BF16_GROUP2) you must verify the divisor matches the inner stride or the causal/SW masks will be applied to the wrong query positions. Please add an assertion or fix to use the true inner dim.
  • [BLOCKER] Typo casul_mask -- mojo_opset/backends/ttx/kernels/mlu/swa.py:830, 833, 879, 881, 883 -- Misspelled identifier used across both branches; rename to causal_mask before this becomes load-bearing API/log noise.
  • [BLOCKER] Hard-coded assert seq_lens <= 4 -- mojo_opset/backends/ttx/kernels/mlu/swa.py:942 -- Magic limit with no explanation or error message; either document why (kernel tile constraint) and surface a clear error, or derive from a constexpr. As written it will silently break larger speculative-decode windows.

Suggestions

Suggestions (4)
  • [MAJOR] Double tl.where for div-by-zero guard is redundant -- mojo_opset/backends/ttx/kernels/mlu/swa.py:901 -- tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) evaluates the same predicate twice; one tl.where with a safe denominator suffices and is cheaper.
  • [MAJOR] Output stride computation ignores non-contiguous o -- mojo_opset/backends/ttx/kernels/mlu/swa.py:919-932 -- Strides for o are computed manually as if contiguous, but o = torch.zeros_like(query) inherits query's memory format. If query is non-contiguous (transposed view), the kernel will write to wrong addresses. Use o.stride() like before.
  • [MINOR] seq_lens shadows tensor name and is actually a scalar -- mojo_opset/backends/ttx/kernels/mlu/swa.py:917-933 -- Rename to q_seq_len to match the kernel's Q_SEQLEN and avoid confusion with total_seq_lens.
  • [MINOR] Reference impl .item() per-batch in a Python loop -- mojo_opset/experimental/operators/attention.py:1227-1231 -- Fine for a reference, but note it forces host sync each iteration; add a comment that this is intentionally a slow reference.

Nits

Nits (3)
  • [NIT] Trailing whitespace on mask = kv_mask[None, :] & casul_mask -- mojo_opset/backends/ttx/kernels/mlu/swa.py:884.
  • [NIT] Removed blank line before return o -- mojo_opset/backends/ttx/kernels/mlu/swa.py:998 -- unrelated style churn.
  • [NIT] Assert message interpolates a tensor inside an f-string indexed by torch.where -- mojo_opset/backends/ttx/kernels/mlu/swa.py:925-926 -- only useful when the assert fails; consider lazy formatting.

Notes

  • [CHECK] MojoPagedDecodeNstepSWA.forward uses total_seq_lens[i].item() and block_table[i,0].item() inside a Python loop -- confirm this reference is only used in tests, not on a hot path.
  • [CHECK] The new operator only registers a mlu backend; ensure bypass_not_implemented actually skips on other platforms in CI rather than silently passing.
  • [CHECK] Verify the kernel still compiles for Q_SEQLEN=1 (legacy decode path) since tl.reshape(..., [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D]) semantics changed.

Comment thread mojo_opset/tests/accuracy/operators/test_attention.py Outdated
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- N-step SWA decode kernel has a likely-incorrect Q reshape that will scramble per-token/per-head Q rows and break correctness.

Summary

Adds an n-step (multi-token) variant of paged decode SWA: a new MojoPagedDecodeNstepSWA operator, a torch reference, a TTX/MLU registration, and an extended Triton kernel that accepts [bsz, seq_len, n_q_heads, head_dim] queries. The motivation is to support speculative decoding / multi-step decode with a sliding window.

Must fix

  • [BLOCKER] Q reshape order mismatches mask layout -- mojo_opset/backends/ttx/kernels/mlu/swa.py:782-786 -- q is loaded as [Q_SEQLEN, BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D] and reshaped to [Q_SEQLEN*BLOCK_SIZE_Q_HEADS, D], so the flattened row index is s*BLOCK_SIZE_Q_HEADS + h. But masks use seq_offsets = arange(...) // BLOCK_SIZE_Q_HEADS, which assumes the same row order; that is consistent only if accumulation also follows (s, h). Verify _decode_acc_fwd_MxN treats the first dim as (s outer, h inner) -- otherwise tokens get the wrong causal/SWA mask. Either way, the safer encoding is offs = s * BLOCK_SIZE_Q_HEADS + h everywhere; please add a comment and a test that varies seq_len > 1 with BLOCK_SIZE_Q_HEADS > 1 to lock it in.
  • [BLOCKER] Typo casul_mask propagated -- mojo_opset/backends/ttx/kernels/mlu/swa.py:830,876,880 -- Misspelled identifier used in the hot path; rename to causal_mask. Not a functional bug but it will be copy-pasted; fix before merge.
  • [BLOCKER] Silent seq_lens <= 4 constraint -- mojo_opset/backends/ttx/kernels/mlu/swa.py:941 -- Bare assert seq_lens <= 4 with no message and no documentation in MojoPagedDecodeNstepSWA. Either raise a clear ValueError explaining the kernel limit, or expose/parameterize it; right now callers will hit a cryptic AssertionError.
  • [BLOCKER] Output stride computation ignores tensor layout -- mojo_opset/backends/ttx/kernels/mlu/swa.py:919-933 -- stride_ob/os/oh/od are computed from shapes assuming contiguous output, but o = torch.empty_like(q) (or whatever the caller allocates) is not guaranteed contiguous, and previously the code used o.stride(...). Use o.stride() like before, or explicitly allocate a contiguous o and assert it. Otherwise non-contiguous outputs will be written to wrong addresses.

Suggestions

Suggestions (4)
  • [MAJOR] Double tl.where for l_i guard is redundant and obscures intent -- mojo_opset/backends/ttx/kernels/mlu/swa.py:901 -- tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) -- the inner where already prevents div-by-zero; the outer where then zeros valid rows where l_i==0 (all-masked rows). A single tl.where(l_i>0, acc / tl.where(l_i>0, l_i, 1.0), 0.0) is fine but please add a comment; previously the code guarded with if kv_seq_len > 0 only.
  • [MAJOR] Commented-out assert in dispatch -- mojo_opset/backends/ttx/kernels/mlu/swa.py:923-924 -- Dead # assert torch.all(seqlens >= seq_lens) should be either enabled (cheap on host int tensor) or removed; leaving it commented is debug residue.
  • [MINOR] _generate_window_mask reused for n-step reference -- mojo_opset/experimental/operators/attention.py:1234-1240 -- Make sure that helper produces a [seq_len, kv_seq_len] causal+window mask aligned to the last seq_len tokens of the kv (i.e. query positions are kv_seq_len - seq_len + i); the reference assumes this. Add an explicit comment or a small unit test for seq_len > 1.
  • [MINOR] Trailing blank line removed before kernel launch -- mojo_opset/backends/ttx/kernels/mlu/swa.py:997 -- Unrelated whitespace churn; revert to keep diff focused.

Nits

Nits (4)
  • [NIT] Indentation changed from 4 to 2 spaces inside if max_seq_len > 0 block -- mojo_opset/tests/accuracy/operators/test_attention.py:50-56 -- inconsistent with file style.
  • [NIT] Trailing whitespace after sw_mask = ... seq_offsets[:, None] -- mojo_opset/backends/ttx/kernels/mlu/swa.py:833.
  • [NIT] extra_repr of new class doesn't include all fields the sibling MojoPagedDecodeSWAWithKVDequant.extra_repr does -- mojo_opset/experimental/operators/attention.py:1262 -- keep them consistent.
  • [NIT] TTXPagedDecodeNstepSWA duplicates almost all of TTXSWA.forward; consider sharing via a helper -- mojo_opset/backends/ttx/operators/attention.py:436-463.

Notes

  • [CHECK] The kernel still uses kv_seq_len - 1 semantics in some places that were converted to kv_seq_len - Q_SEQLEN + seq_offsets; confirm _swa_split_blocks(kv_seq_len - Q_SEQLEN, Q_SEQLEN, ...) produces the correct non_global_window_start_block when the local window straddles the global window for the earliest query token, not just the last.
  • [CHECK] seq_lens is passed as a Python int via tl.constexpr Q_SEQLEN; ensure callers don't pass a tensor and that the value is JIT-cache-friendly (e.g. always 1/2/3/4).

@Neuromancer42 Neuromancer42 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Neuromancer42 Neuromancer42 merged commit 4dd96e1 into master Jun 25, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants