[KMCompiler][ttx] Optimize rms_norm for small cols #363

Open
YangLong114514 wants to merge 2 commits into XPU-Forces:master from YangLong114514:KMCompiler-RmsNorm
@YangLong114514

Description

This PR optimizes the rms_norm operator on the Ascend platform for inputs with a small number of columns.

Changes

  1. Added _rmsnorm_infer_small_cols_kernel for n_cols <= 2048.

  2. Updated the BLOCK_SIZE_M selection logic for inference.
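The dispatch described in item 1 can be sketched roughly as follows. This is only an illustration of the threshold check, not the code in this PR: `select_infer_kernel` and the fallback name `_rmsnorm_infer_kernel` are hypothetical, and only the 2048 threshold and `_rmsnorm_infer_small_cols_kernel` come from the description above.

```python
# Threshold taken from the PR description: the new kernel handles n_cols <= 2048.
SMALL_COLS_THRESHOLD = 2048


def select_infer_kernel(n_cols: int) -> str:
    """Return the name of the kernel that would handle rows of n_cols elements.

    Hypothetical helper for illustration only; the real dispatch in the PR
    may look different.
    """
    if n_cols <= SMALL_COLS_THRESHOLD:
        return "_rmsnorm_infer_small_cols_kernel"
    # Assumed name for the pre-existing inference path.
    return "_rmsnorm_infer_kernel"


for n_cols in (128, 2048, 4096):
    print(n_cols, "->", select_infer_kernel(n_cols))
```

Under this sketch, the boundary value 2048 still takes the small-cols path, matching the `<=` in the description.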

Performance

Measured on Ascend 910B with Triton 3.2.x (FlagTree) and cann-8.5.0:

| Shape | Dtype | Before | After | Speedup |
| --- | --- | --- | --- | --- |
| (1, 8, 128) | float32 | 4.1424 | 2.1648 | 1.91 |
| (4, 32, 256) | float32 | 5.3104 | 3.1248 | 1.70 |
| (8, 256, 2048) | float32 | 25.9360 | 25.3568 | 1.02 |
| (8, 256, 4096) | float32 | 38.6848 | 38.6608 | 1.00 |
| (1, 16, 4096) | float32 | 6.6352 | 4.68 | 1.42 |
| (4, 32, 2048) | float32 | 5.944 | 5.4448 | 1.09 |
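The speedup column appears to be the ratio of the "before" time to the "after" time, rounded to two decimals. A small sketch to recompute it from the figures above (the `speedup` helper is hypothetical, and the time unit is not stated in the PR, so it is left unspecified here):

```python
def speedup(before: float, after: float) -> float:
    """Speedup as before/after, rounded to two decimals (assumed definition)."""
    return round(before / after, 2)


# (shape, before, after) rows copied from the table above.
rows = [
    ((1, 8, 128), 4.1424, 2.1648),
    ((4, 32, 256), 5.3104, 3.1248),
    ((8, 256, 2048), 25.9360, 25.3568),
    ((8, 256, 4096), 38.6848, 38.6608),
    ((1, 16, 4096), 6.6352, 4.68),
    ((4, 32, 2048), 5.944, 5.4448),
]

for shape, before, after in rows:
    print(shape, speedup(before, after))
```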

Accuracy test

platform linux -- Python 3.11.13, pytest-8.3.2, pluggy-1.6.0
rootdir: /data/baai_user_home/jstar/mojo-work/mojo_opset
configfile: pytest.ini
plugins: xdist-3.6.1, anyio-4.10.0
collected 10 items

test_normalization.py::test_rmsnorm[1e-05-dtype0-shape0] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype0-shape1] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype0-shape2] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype0-shape3] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype0-shape4] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype1-shape0] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype1-shape1] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype1-shape2] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype1-shape3] PASSED
test_normalization.py::test_rmsnorm[1e-05-dtype1-shape4] PASSED

============================ 10 passed in 24.16s ==========================

@YangLong114514 changed the title from "[KMCompiler]Optimize rms_norm for small cols" to "[KMCompiler][ttx]Optimize rms_norm for small cols" on Jun 15, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a new Triton kernel _rmsnorm_infer_small_cols_kernel to optimize RMSNorm inference for small column sizes, and updates the implementation to dynamically calculate BLOCK_SIZE_M and conditionally dispatch the appropriate kernel. The review feedback highlights critical issues where block sizes (BLOCK_SIZE_M and BLOCK_SIZE_N) may not be powers of two, which would cause Triton compilation failures, and provides actionable suggestions to resolve them.
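To illustrate the power-of-two concern raised in the review, here is a minimal sketch of how a computed block size could be checked and rounded before it is passed to a kernel. The helper names are hypothetical and are not the reviewer's actual suggestions; Triton also provides `triton.next_power_of_2` for the round-up case.

```python
def is_power_of_two(n: int) -> bool:
    """True if n is a positive power of two (1, 2, 4, 8, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (round up), e.g. for BLOCK_SIZE_N covering n_cols."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n - 1).bit_length()


def prev_power_of_two(n: int) -> int:
    """Largest power of two <= n (round down), e.g. to keep a block within a budget."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n.bit_length() - 1)


for candidate in (1, 3, 48, 2048, 3000):
    print(candidate, is_power_of_two(candidate),
          next_power_of_two(candidate), prev_power_of_two(candidate))
```

Whether rounding up or down is appropriate depends on the constraint: a column block usually has to cover the whole row (round up, with masking), while a row block derived from a memory budget would more likely be rounded down.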

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/backends/ttx/kernels/npu/rmsnorm.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/npu/rmsnorm.py