Skip to content

switch to triton ascend 3.2.1#281

Open
kevin-hongkai wants to merge 34 commits into
masterfrom
hongkai/switch-to-triton-ascend
Open

switch to triton ascend 3.2.1#281
kevin-hongkai wants to merge 34 commits into
masterfrom
hongkai/switch-to-triton-ascend

Conversation

@kevin-hongkai

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request migrates NPU Triton kernels to support Triton Ascend 3.2.1, primarily by updating slicing, math, and hint operations to the tl.extra.cann namespace and enabling NaN propagation in reduction operations. Feedback identifies several critical issues, including potential integer wrap-around in quantization due to the removal of saturation modes, the accidental deletion of the byted-triton-x dependency in pyproject.toml, and the presence of dead code. Additionally, the reviewer recommends implementing safe division patterns in the SDPA kernel and suggests making the hardcoded normalization casting mode configurable.

Comment thread mojo_opset/backends/ttx/kernels/npu/quant.py
Comment thread mojo_opset/backends/ttx/kernels/npu/sdpa.py
Comment thread pyproject.toml
Comment thread mojo_opset/backends/ttx/functions/normalization.py
Comment thread mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py Outdated
@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Migration to vendor-namespaced Triton extensions looks mechanical and correct, but a few items (a perf-killing perf call removed, a dropped overflow_mode, tl.range vs tl.static_range, and propagate_nan on integer reductions) need attention before merge.

Summary

This PR migrates kernels from generic tl.extract_slice/tl.insert_slice/tl.compile_hint/tl.flip/tl.get_element to the vendor-specific tl.extra.cann.extension.* API and adds propagate_nan=tl.PropagateNan.ALL to tl.max/tl.maximum calls, matching the new byted-triton-x>=3.2.1 / CANN image. It also tweaks a few host-side tests, an indexer call signature, and adds enable_ubuf_saving/sync_solver kernel kwargs.

Must fix

  • [BLOCKER] Quant lost saturation on int8 cast -- mojo_opset/backends/ttx/kernels/npu/quant.py:154 -- Removing overflow_mode="saturate" means out-of-range values now wrap on cast to int8, which silently corrupts quantized outputs whenever the rounding +/- 0.5 pushes a value past +/-127. Either keep saturation or clamp explicitly before cast.
  • [BLOCKER] Indexer quant call: dropped scale arg + unconditional squeeze -- mojo_opset/experimental/operators/indexer.py:92-94 -- self.quant(q) drops the previous explicit None scale (API change not visible in this diff) and then q_scale.squeeze(-1) is unconditional, unlike the conditional k_scale.dim()==3 branch right below. If q_scale is not 3-D this will silently change shape; please mirror the conditional handling and confirm the new quant signature.
  • [BLOCKER] propagate_nan applied to integer reductions -- mojo_opset/backends/ttx/kernels/npu/sample.py:285,288 -- y_idx * (1 - mask) is an integer index tensor; tl.PropagateNan.ALL on integer tl.max is meaningless at best and may be rejected by the compiler depending on the version. Drop propagate_nan on these two calls.
  • [BLOCKER] tl.range vs tl.static_range over a runtime-bounded loop with constexpr indexing -- mojo_opset/backends/ttx/kernels/npu/over_encoding/fused_over_encoding.py:120-123 and .../n_gram.py analog -- The loop uses tl.static_range if BLOCK_BATCH_SIZE < 4 else tl.range but the body calls tl.extra.cann.extension.get_element(n_gram_ids, (ele_idx,)), which typically requires a compile-time index. If tl.range is taken, ele_idx is dynamic and codegen will fail or silently degrade. Please verify and force tl.static_range, or switch to a dynamic indexing primitive.
  • [BLOCKER] CI workflow uninstalls/reinstalls Triton imperatively -- .github/workflows/ascend_accuracy_ci.yml:59-61 -- rm -rf /usr/local/lib/python3.11/dist-packages/triton after pip uninstall is brittle (path is Python-version-specific and races with the pip metadata it just removed). Prefer pinning byted-triton-x in the image or using pip install --force-reinstall without manual rm.

Suggestions

Suggestions (5)
  • [MAJOR] Removed reference perf measurement -- mojo_opset/tests/perf/test_indexer.py:32-33 -- Dropping indexer_ref perf removes the only baseline comparison; if intentional (registry no longer has torch impl) please leave a comment, otherwise restore it.
  • [MAJOR] Removed tl.device_print is good, but check it was the only diagnostic -- mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py:188 -- Confirm no other debug prints remain in this file family; this one was clearly residue.
  • [MAJOR] enable_ubuf_saving=True / sync_solver=False are vendor flags with no comment -- mojo_opset/backends/ttx/kernels/npu/swa.py:633,1027, group_gemm.py:227,327 -- Add a one-line comment explaining the perf/correctness trade-off; otherwise future readers will not know whether they can flip them.
  • [MAJOR] propagate_nan semantics: confirm intent -- multiple files -- For softmax-style code tl.max(qk, ...) followed by exp(qk - m), propagating NaN turns one NaN into all-NaN output, which may be desired for masking debug but differs from prior behaviour where masked -inf rows produced -inf max. Please confirm this is the intended semantics under the new compiler.
  • [MINOR] Test now allocates on device but sizes differ from old shapes_label_cache/shapes_key_lr lists -- mojo_opset/tests/perf/test_store_lowrank.py:25-31 -- Module-level shapes_* and slot_mappings are now dead; remove them to avoid confusion.

Nits

Nits (5)
  • [NIT] Many added kwargs lack a space after the comma, e.g. propagate_nan=tl.PropagateNan.ALL) after axis=1, -- run formatter -- flash_attention.py:426, sample.py:285, swa.py:219, etc.
  • [NIT] Trailing-newline removed in two files -- .github/workflows/ascend_accuracy_ci.yml:119, mojo_opset/tests/perf/test_store_lowrank.py:47 -- restore final newline.
  • [NIT] Stray unrelated whitespace-only edits -- mojo_opset/core/functions/normalization.py:39, mojo_opset/tests/accuracy/functions/test_normalization.py:42, mojo_opset/backends/ttx/kernels/npu/sdpa.py:805 -- drop from this PR.
  • [NIT] mask=block_mask_1 is shaped [1, BLOCK_SIZE_N] -- the prior tl.view(block_mask, (1, BLOCK_SIZE_N)) was equivalent; new variable name block_mask_1 is unclear, consider block_mask_2d -- n_gram.py:167.
  • [NIT] pytest >= 8.4.0 -> >= 8.3.2 is a downgrade -- please note in the PR description why this is required (likely the new CI image), otherwise it looks accidental -- pyproject.toml:16.

Notes

  • [CHECK] tl.extra.cann.extension.compile_hint(b, "dot_pad_only_k") is called on b after b = tl.trans(b) in _m_grouped_matmul_bKmajor_kernel -- confirm the hint applies to the transposed value and not the pre-trans load (was the same in the old code, but worth re-verifying with the new extension semantics) -- group_gemm.py:113.
  • [CHECK] sdpa.py:85 changes the compile hint from "tile_cube_loop" to "hivm.tile_mix_cube_num" with value 2. Verify the new hint name and value are equivalent under the new triton-x; otherwise decode/prefill perf may regress silently.
  • [CHECK] causal_conv1d_update_kernel_bdt_fwd now casts out_block to x_ptr.dtype.element_ty before storing under SILU_ACTIVATION only -- confirm the non-SILU path already stores the correct dtype; otherwise this is asymmetric -- convolution.py:742.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- Several concerning changes: a default RMSNorm casting mode flip, a removed overflow_mode="saturate" on int8 quant cast, and a perf test that no longer compares against reference.

Summary

This PR adapts kernels to a new byted-triton-x (CANN extension namespaced ops, propagate_nan requirement on reductions/maximum) and updates the NPU CI image. It also tweaks a few host-side functions and tests.

Must fix

  • [BLOCKER] RMSNorm default casting_mode changed from "llama" to "gemma" -- mojo_opset/backends/ttx/functions/normalization.py:26 -- This silently changes numerical behavior of RMSNorm for all callers (gemma casts weight differently). Either revert or expose the field on the base class instead of flipping the hardcoded default.
  • [BLOCKER] int8 quant cast lost overflow_mode="saturate" -- mojo_opset/backends/ttx/kernels/npu/quant.py:154 -- Without saturation, out-of-range values wrap; the manual +/- 0.5 rounding can still exceed [-127,127] (e.g. NaNs/inf or scale=0 paths), producing wrong int8 values. Restore saturate or guarantee saturation upstream.
  • [BLOCKER] Perf test dropped reference comparison -- mojo_opset/tests/perf/test_indexer.py:32 -- Removing the torch reference call makes the perf test purely one-sided; intentional? If yes, rename/document; otherwise restore indexer_ref perf so regressions vs reference remain detectable.
  • [BLOCKER] compile_hint string changed: "tile_cube_loop" -> "hivm.tile_mix_cube_num" -- mojo_opset/backends/ttx/kernels/npu/sdpa.py:85 -- All other call sites kept their original hint strings (e.g. "dot_pad_only_k"); only this one was renamed. Confirm this is the intended new hint name and not a copy-paste; a wrong hint silently degrades performance on a hot decode path.
  • [BLOCKER] MojoIndexer.quant call signature changed and result reshaped -- mojo_opset/experimental/operators/indexer.py:92-94 -- Switched from quant(q, None) to quant(q) plus a new q_scale.squeeze(-1); if quant still returns a 3D scale this squeeze is a no-op or wrong depending on shape. Add a shape assertion or align with the k_scale.dim()==3 branch handling.

Suggestions

Suggestions (5)
  • [MAJOR] Stray device_print removal but no test for the previous output -- mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py:188 -- Good removal of debug residue; verify there is a test exercising MTP_STEP>1 path that previously relied on the printed value being incidental.
  • [MAJOR] sync_solver=False added to grouped matmul launches -- mojo_opset/backends/ttx/kernels/npu/group_gemm.py:227,327 -- This changes synchronization semantics; please document why it is safe (no cross-core dependency) and add a comment in the kernel.
  • [MAJOR] enable_ubuf_saving=True added unconditionally -- mojo_opset/backends/ttx/kernels/npu/swa.py:633,1027 -- Confirm this does not regress shapes that previously fit in UB; consider gating on tile size.
  • [MINOR] Tests changed from CPU tensors to device tensors -- mojo_opset/tests/perf/test_store_lowrank.py:36-44 -- New test no longer parameterizes over the original slot_mappings set; coverage of small kv_lens (1, 24) was dropped.
  • [MINOR] int64 -> int32 change in test util -- mojo_opset/tests/perf/test_linear.py:22 -- Silent dtype change; ensure the kernel under test accepts int32 group sizes (likely yes, but worth a comment).

Nits

Nits (5)
  • [NIT] Missing newline at EOF -- .github/workflows/ascend_accuracy_ci.yml:119 and mojo_opset/tests/perf/test_store_lowrank.py:47.
  • [NIT] Inconsistent spacing around kwargs propagate_nan=... (some have leading space, some not) -- e.g. mojo_opset/backends/ttx/kernels/npu/flash_attention.py:426 vs mojo_opset/backends/ttx/kernels/npu/sample.py:790.
  • [NIT] Commented-out m_ij = tl.maximum(...) left in -- mojo_opset/backends/ttx/kernels/npu/sdpa.py:69.
  • [NIT] Unrelated whitespace-only edits in core/functions/normalization.py:38 and tests/accuracy/functions/test_normalization.py:42 -- revert to keep diff minimal.
  • [NIT] pip install byted-triton-x>=3.2.1 should be quoted in shell to avoid redirection -- .github/workflows/ascend_accuracy_ci.yml:60.

Notes

  • [CHECK] All tl.extra.cann.extension.* and tl.extra.cann.math.tanh symbols exist in the pinned byted-triton-x>=3.2.1; if not, kernels will fail at compile time only when first invoked.
  • [CHECK] The propagate_nan=tl.PropagateNan.ALL change for tl.max/tl.maximum may alter softmax behavior when inputs contain NaN (previously NaNs may have been silently dropped); verify this matches intended semantics on attention masks set to -inf.

@kevin-hongkai

Copy link
Copy Markdown
Collaborator Author

Must fix

  • [BLOCKER] RMSNorm default casting_mode changed from "llama" to "gemma" -- mojo_opset/backends/ttx/functions/normalization.py:26 -- This silently changes numerical behavior of RMSNorm for all callers (gemma casts weight differently). Either revert or expose the field on the base class instead of flipping the hardcoded default.
    -----rms_norm 算子实现内有多个 cast_mode,默认走的 llama_cast_mode 存在形如 norm(x1.to(fp32)).to(x1.dtype) 的操作; 作为标杆的 torch 自带的 F.rms_norm 则是使用 fp32 高精度计算下来的,这会导致散点式的精度误差。当前可以将 cast_mode 配置为 gemma_mode 以规避该问题
  • [BLOCKER] int8 quant cast lost overflow_mode="saturate" -- mojo_opset/backends/ttx/kernels/npu/quant.py:154 -- Without saturation, out-of-range values wrap; the manual +/- 0.5 rounding can still exceed [-127,127] (e.g. NaNs/inf or scale=0 paths), producing wrong int8 values. Restore saturate or guarantee saturation upstream.
    ------不再需要添加overflow_mode="saturate",triton-ascend新版本也不再支持此字段的传入
  • [BLOCKER] Perf test dropped reference comparison -- mojo_opset/tests/perf/test_indexer.py:32 -- Removing the torch reference call makes the perf test purely one-sided; intentional? If yes, rename/document; otherwise restore indexer_ref perf so regressions vs reference remain detectable.
    ------torch在perf时容易卡死,耗时很常,没必要做torch性能
  • [BLOCKER] compile_hint string changed: "tile_cube_loop" -> "hivm.tile_mix_cube_num" -- mojo_opset/backends/ttx/kernels/npu/sdpa.py:85 -- All other call sites kept their original hint strings (e.g. "dot_pad_only_k"); only this one was renamed. Confirm this is the intended new hint name and not a copy-paste; a wrong hint silently degrades performance on a hot decode path.
  • [BLOCKER] MojoIndexer.quant call signature changed and result reshaped -- mojo_opset/experimental/operators/indexer.py:92-94 -- Switched from quant(q, None) to quant(q) plus a new q_scale.squeeze(-1); if quant still returns a 3D scale this squeeze is a no-op or wrong depending on shape. Add a shape assertion or align with the k_scale.dim()==3 branch handling.
    ------经测试,indexer中接口变更了,但是上层调用还是使用quant(q, None),会报错参数不匹配

Suggestions

Suggestions (5)

  • [MAJOR] Stray device_print removal but no test for the previous output -- mojo_opset/backends/ttx/kernels/npu/over_encoding/n_gram.py:188 -- Good removal of debug residue; verify there is a test exercising MTP_STEP>1 path that previously relied on the printed value being incidental.
    -------原来为了解决load在IR消失的问题,才加device_print;新版本triton-ascend已解决,可去除device_print
  • [MAJOR] sync_solver=False added to grouped matmul launches -- mojo_opset/backends/ttx/kernels/npu/group_gemm.py:227,327 -- This changes synchronization semantics; please document why it is safe (no cross-core dependency) and add a comment in the kernel.
    -----triton-ascend中groupgemm性能退化20%,是由于同步引起的,cube类不需要加同步,可使用sync_solver编译选项先去除同步; 后面版本的编译器已识别该问题,后续版本会解决性能退化问题。
  • [MAJOR] enable_ubuf_saving=True added unconditionally -- mojo_opset/backends/ttx/kernels/npu/swa.py:633,1027 -- Confirm this does not regress shapes that previously fit in UB; consider gating on tile size.
    -----需要添加enable_ubuf_saving,否则会ub overflow
  • [MINOR] int64 -> int32 change in test util -- mojo_opset/tests/perf/test_linear.py:22 -- Silent dtype change; ensure the kernel under test accepts int32 group sizes (likely yes, but worth a comment).
    ----测试发现,perf的用例test_linear.py 中int64类型会报错类型不匹配,改为int32修复

Notes

  • [CHECK] All tl.extra.cann.extension.* and tl.extra.cann.math.tanh symbols exist in the pinned byted-triton-x>=3.2.1; if not, kernels will fail at compile time only when first invoked.
    ------triton-ascend新版本中为兼容triton社区,NPU的所有扩展接口都改为tl.extra.cann.extention路径
  • [CHECK] The propagate_nan=tl.PropagateNan.ALL change for tl.max/tl.maximum may alter softmax behavior when inputs contain NaN (previously NaNs may have been silently dropped); verify this matches intended semantics on attention masks set to -inf.
    ------max添加propagate_nan=tl.PropagateNan.ALL可增加性能。编译器解释原因为:增加propagate_nan=tl.PropagateNan.ALL:底层的指令max针对nan的处理是 max(nan, a) = nan 或者 max(nan, nan)=nan,也就是 arith.maximumf: 传播 NaN:只要任一操作数是 NaN,结果就是 NaN, maximumf(1.0, NaN) = NaN的语义
    因此前端通过propagate nan来区分,propagate nan=propagate.ALL->默认生成arith.maximumf, 否则是arith.maxnumf, 因为底层指令不支持,所以会插入一些模拟指令处理边界

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant