[KMCompiler][ttx] Optimize NPU ResidualAddRMSNorm forward performance #367
YangLong114514 wants to merge 3 commits into
Code Review
This pull request optimizes the fused_add_rmsnorm Triton kernel on NPU by introducing a single-pass kernel for smaller column sizes, adding heuristics for row block sizes, and using compile-time constants to conditionally skip row/column masking and storing operations. The review feedback highlights three key improvement opportunities: ensuring the grid size is at least 1 to prevent runtime errors when n_rows is 0, using a cached utility to query NPU device properties instead of querying them on every call, and decoupling the row-masking optimization from the column-masking optimization to maximize performance gains.
Description
Optimize NPU ResidualAddRMSNorm forward performance by reducing intermediate tensor traffic, eliminating unnecessary masks, and tuning kernel scheduling for different shapes.
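For readers unfamiliar with the operator, the semantics can be written as a plain-Python reference. This is a sketch under assumptions, not the kernel in this PR: the function name, the `eps` default, and the list-of-lists layout are illustrative only. The point is that the sum `S` lives in a row-local variable and is reused for both the reduction and the scaling, so no separate intermediate tensor has to be written and re-read.

```python
import math


def residual_add_rmsnorm(x, residual, weight, eps=1e-6):
    """Reference semantics: out = (x + residual) * rstd * weight, per row.

    Returns the normalized rows and the per-row reciprocal RMS (rstd).
    """
    out, rstd = [], []
    for x_row, r_row in zip(x, residual):
        # S = X + residual stays in a local variable for the whole row.
        s = [a + b for a, b in zip(x_row, r_row)]
        mean_sq = sum(v * v for v in s) / len(s)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * w for v, w in zip(s, weight)])
        rstd.append(inv_rms)
    return out, rstd
```

A kernel that keeps the whole row resident can follow the same shape: one load of `x` and `residual`, one reduction, one store of the output.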
Changes
- Keep `S = X + residual` on-chip and directly compute the RMSNorm output in the single-pass path
- Avoid writing the intermediate `S` tensor in post mode
- Add a `STORE_RSTD` switch to control RSTD write-back while preserving the current training-compatible behavior
- Tune `BLOCK_SIZE_M` based on `n_rows` and `n_cols`

Performance
Measured with `torch.float32` inputs.
Overall speedup: 1.25x–2.35x.
Post mode generally benefits more from eliminating the intermediate `S` write, while large hidden dimensions benefit from scheduling and `BLOCK_SIZE_M` tuning.

Accuracy
Accuracy tests in `mojo_opset/tests/accuracy/operators/test_normalization.py` passed.