Commit 781083b

[Optimize, NPU] Remove tl.where from _rms_norm_forward/backward_kernel_tiled() (#1153)
## Summary

When the mask has a large shape, `tl.where` is not NPU-friendly in `triton-ascend` and leads to low kernel performance. When writing kernels, it is best to use alternative logic instead; doing so can yield significant performance improvements.

**Will these changes affect accuracy?** No. The masking has already been applied when loading `X_block`, so out-of-range elements are zero and removing the extra `tl.where` does not change the result of `tl.sum`.

## Testing Done

### Accuracy first

The shapes in `test_rms_norm.py` are too small to trigger the `_rms_norm_forward_kernel_tiled` kernel, so a new configuration is needed:

```python
@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2048, 4096),
        (2, 2048, 8192),
        (2, 2048, 16384),
        (2, 2048, 32768),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
        pytest.param(
            BaseRMSNorm,
            0.0,
            "none",
            marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
        ),
    ],
)
@pytest.mark.parametrize(
    "in_place",
    [
        True,
        False,
    ],
)
@pytest.mark.parametrize(
    "elementwise_affine",
    [
        True,
        False,
    ],
)
```

#### Env

<img width="696" height="251" alt="image" src="https://github.com/user-attachments/assets/c0cd523f-8a78-4205-8f18-6854477b9d0a" />

#### Results after code modification

<img width="1178" height="836" alt="image" src="https://github.com/user-attachments/assets/b520faa0-f209-4ce0-9574-6000bb612f91" />

### Benchmark test

The test cases in `benchmark_rms_norm.py` should use the same shapes as those in `test_rms_norm.py`:

```python
common_configs = {
    ...
    "x_values": [2**i for i in range(12, 16)],
    ...
}
```

#### Before Optimization

forward

<img width="1000" height="600" alt="rms_norm_speed_forward" src="https://github.com/user-attachments/assets/92a701fc-37a9-4298-a285-95174f97ef98" />

backward

<img width="1000" height="600" alt="rms_norm_speed_backward" src="https://github.com/user-attachments/assets/8cfa175d-d6d3-4472-9b9d-7162e6a9df02" />

full

<img width="1000" height="600" alt="rms_norm_speed_full" src="https://github.com/user-attachments/assets/f8248fc6-52a5-4963-a74e-54f7d85173cd" />

memory

<img width="1000" height="600" alt="rms_norm_memory_full" src="https://github.com/user-attachments/assets/7d4b01bb-d425-4e2b-87ef-3d3b1638324b" />

[all_benchmark_data_raw.csv](https://github.com/user-attachments/files/26131388/all_benchmark_data_raw.csv)

#### After Optimization

forward

<img width="1000" height="600" alt="rms_norm_speed_forward" src="https://github.com/user-attachments/assets/89b1a07f-273a-4d07-a6c2-5eb20c7abeb1" />

backward

<img width="1000" height="600" alt="rms_norm_speed_backward" src="https://github.com/user-attachments/assets/4bd741c9-d401-4ade-ac7c-e3b8a241bc32" />

full

<img width="1000" height="600" alt="rms_norm_speed_full" src="https://github.com/user-attachments/assets/c11f7b8c-a3d9-433e-a24e-ebdae343e2e9" />

memory

<img width="1000" height="600" alt="rms_norm_memory_full" src="https://github.com/user-attachments/assets/e2bfedf2-65b2-440d-96fc-4f389cd5ddda" />

[all_benchmark_data_optimized.csv](https://github.com/user-attachments/files/26131403/all_benchmark_data_optimized.csv)

- Hardware Type: Atlas 900 A2 PoD
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
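The accuracy argument above can be checked with a small NumPy sketch (hypothetical shapes; `np.where` stands in for `tl.where`, and a zero-padded array stands in for a masked `tl.load(..., mask=mask, other=0.0)`):

```python
import numpy as np

# Simulate a tile load: n_cols valid elements padded to BLOCK_SIZE,
# with out-of-range lanes filled with 0.0, as a masked tl.load does.
BLOCK_SIZE, n_cols = 8, 5
rng = np.random.default_rng(0)
row = rng.standard_normal(n_cols).astype(np.float32)

mask = np.arange(BLOCK_SIZE) < n_cols
X_block = np.zeros(BLOCK_SIZE, dtype=np.float32)
X_block[:n_cols] = row  # masked load: invalid lanes are already 0.0

# Old logic: mask again inside the reduction.
sum_sq_where = np.where(mask, X_block * X_block, 0.0).sum()
# New logic: the padded zeros contribute nothing, so the extra where is redundant.
sum_sq_plain = (X_block * X_block).sum()

assert np.isclose(sum_sq_where, sum_sq_plain)
```

Because `0.0 * 0.0 == 0.0`, the masked lanes add nothing to the sum either way; the same reasoning covers `sum_m_X` in the backward kernel.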
1 parent 830ed7c commit 781083b

File tree

1 file changed: +2 −2 lines changed

src/liger_kernel/ops/backends/_ascend/ops/rms_norm.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -199,7 +199,7 @@ def _rms_norm_forward_kernel_tiled(
         X_block = X_block.to(tl.float32)
 
         # Accumulate sum of squares (only for valid elements)
-        sum_square += tl.sum(tl.where(mask, X_block * X_block, 0.0))
+        sum_square += tl.sum(X_block * X_block)
 
         # Compute rstd for this row
         mean_square = sum_square / n_cols
@@ -456,7 +456,7 @@ def _rms_norm_backward_kernel_tiled(
         m = dY_block
 
         # Accumulate sum(m * X)
-        sum_m_X += tl.sum(tl.where(mask, m * X_block, 0.0))
+        sum_m_X += tl.sum(m * X_block)
 
         # Compute the correction factor
         correction_factor = -(1.0 / n_cols) * rstd * rstd * sum_m_X
```
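For context, the tiled accumulation that the diff touches can be simulated in plain NumPy (hypothetical helper name and block size; the real Triton kernel lives in `src/liger_kernel/ops/backends/_ascend/ops/rms_norm.py`):

```python
import numpy as np

def rms_norm_rstd_tiled(x_row, block_size, eps=1e-6):
    """Simulate the forward kernel's rstd computation: accumulate the
    sum of squares tile by tile, padding the short final tile with 0.0
    (the masked-load behaviour) so no per-element where is needed."""
    n_cols = x_row.shape[0]
    sum_square = 0.0
    for start in range(0, n_cols, block_size):
        tile = x_row[start:start + block_size].astype(np.float32)
        if tile.shape[0] < block_size:
            # Pad like tl.load(..., mask=mask, other=0.0) would.
            tile = np.pad(tile, (0, block_size - tile.shape[0]))
        sum_square += float((tile * tile).sum())  # no where/masking here
    mean_square = sum_square / n_cols
    return 1.0 / np.sqrt(mean_square + eps)

x = np.random.default_rng(1).standard_normal(100).astype(np.float32)
rstd_tiled = rms_norm_rstd_tiled(x, block_size=32)
rstd_ref = 1.0 / np.sqrt((x.astype(np.float32) ** 2).mean() + 1e-6)
assert np.isclose(rstd_tiled, rstd_ref, rtol=1e-5)
```

The last tile (4 of 32 lanes valid here) is padded with zeros, so the tiled result matches the unblocked reference; this is why dropping `tl.where` from the accumulation is safe.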
