Skip to content

[KMCompiler][ttx] Optimize NPU ResidualAddRMSNorm forward performance#367

Open
YangLong114514 wants to merge 3 commits into
XPU-Forces:masterfrom
YangLong114514:KMCompiler-FuseAddRmsNorm
Open

[KMCompiler][ttx] Optimize NPU ResidualAddRMSNorm forward performance#367
YangLong114514 wants to merge 3 commits into
XPU-Forces:masterfrom
YangLong114514:KMCompiler-FuseAddRmsNorm

Conversation

@YangLong114514

Copy link
Copy Markdown

Description

Optimize NPU ResidualAddRMSNorm forward performance by reducing intermediate tensor traffic, eliminating unnecessary masks, and tuning kernel scheduling for different shapes.

Changes

  • Add a single-pass kernel when the hidden dimension fits in one tile
  • Keep S = X + residual on-chip and directly compute the RMSNorm output in the single-pass path
  • Skip allocating and writing the intermediate S tensor in post mode
  • Add a compile-time STORE_RSTD switch to control RSTD write-back while preserving the current training-compatible behavior
  • Add compile-time no-mask paths for row- and column-aligned shapes
  • Dynamically trim the grid based on the actual number of row tasks
  • Tune BLOCK_SIZE_M based on n_rows and n_cols
  • Use conservative row tiles for large hidden dimensions to avoid UB overflow
  • Preserve FP32 accumulation for RMS reduction
  • Retain the original multi-pass path for large hidden dimensions

Performance

Measured with torch.float32 inputs.

Shape Mode Before [us] After [us] Speedup
(8, 8192) pre 11.7296 5.7968 2.02x
(8, 8192) post 12.0592 5.2912 2.28x
(16, 4096) pre 9.0768 4.2528 2.13x
(16, 4096) post 9.3024 3.9584 2.35x
(32, 128) pre 5.3440 4.2912 1.25x
(32, 128) post 5.8528 2.6288 2.23x
(32, 2048) pre 7.3440 3.7200 1.97x
(32, 2048) post 7.3264 3.7232 1.97x
(128, 128) pre 6.3632 2.7936 2.28x
(128, 128) post 6.3520 2.8368 2.24x
(128, 2048) pre 8.0736 5.9200 1.36x
(128, 2048) post 7.9728 6.3888 1.25x
(256, 512) pre 7.2672 4.5280 1.60x
(256, 512) post 6.7136 4.5040 1.49x
(1024, 1024) pre 13.9440 11.1152 1.25x
(1024, 1024) post 13.9856 10.6720 1.31x

Overall speedup: 1.25x–2.35x.

Post mode generally benefits more from eliminating the intermediate S write, while large hidden dimensions benefit from scheduling and BLOCK_SIZE_M tuning.

Accuracy

Accuracy tests in mojo_opset/tests/accuracy/operators/test_normalization.py passed.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes the fused_add_rmsnorm Triton kernel on NPU by introducing a single-pass kernel for smaller column sizes, adding heuristics for row block sizes, and using compile-time constants to conditionally skip row/column masking and storing operations. The review feedback highlights three key improvement opportunities: ensuring the grid size is at least 1 to prevent runtime errors when n_rows is 0, using a cached utility to query NPU device properties instead of querying them on every call, and decoupling the row-masking optimization from the column-masking optimization to maximize performance gains.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread mojo_opset/backends/ttx/kernels/npu/fused_add_rmsnorm.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/npu/fused_add_rmsnorm.py Outdated
Comment thread mojo_opset/backends/ttx/kernels/npu/fused_add_rmsnorm.py Outdated
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant