Commit 781083b
authored
[Optimize, NPU] Remove tl.where from _rms_norm_forward/backward_kernel_tiled() (#1153)
## Summary
When the mask has a large shape, `tl.where` is not NPU-friendly in
`triton-ascend`, leading to low kernel performance. When writing
kernels, it's best to use alternative logic instead. This can result in
significant performance improvements.
**Will these changes affect accuracy?** Since the masking operation has
already been applied when loading X_block, it will not affect the
calculation result during tl.sum.
## Testing Done
### Accuracy first
The shapes in `test_rms_norm.py` are too small to trigger the
`_rms_norm_forward_kernel_tiled` kernel, so we need a new configuration.
```python
@pytest.mark.parametrize(
"bs, sl, hd",
[
(2, 2048, 4096),
(2, 2048, 8192),
(2, 2048, 16384),
(2, 2048, 32768),
],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
(torch.float32, 1e-4, 1e-6),
],
)
@pytest.mark.parametrize(
"reference, offset, casting_mode",
[
(LlamaRMSNorm, 0.0, "llama"),
(GemmaRMSNorm, 1.0, "gemma"),
pytest.param(
BaseRMSNorm,
0.0,
"none",
marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
),
],
)
@pytest.mark.parametrize(
"in_place",
[
True,
False,
],
)
@pytest.mark.parametrize(
"elementwise_affine",
[
True,
False,
],
)
```
#### Env
<img width="696" height="251" alt="image"
src="https://github.com/user-attachments/assets/c0cd523f-8a78-4205-8f18-6854477b9d0a"
/>
#### Results after code modification
<img width="1178" height="836" alt="image"
src="https://github.com/user-attachments/assets/b520faa0-f209-4ce0-9574-6000bb612f91"
/>
### Benchmark test
The test cases in `benchmark_rms_norm.py` should keep the same shapes as
those in `test_rms_norm.py`.
```python
common_configs = {
...
"x_values": [2**i for i in range(12, 16)],
...
}
```
#### Before Optimization
forward
<img width="1000" height="600" alt="rms_norm_speed_forward"
src="https://github.com/user-attachments/assets/92a701fc-37a9-4298-a285-95174f97ef98"
/>
backward
<img width="1000" height="600" alt="rms_norm_speed_backward"
src="https://github.com/user-attachments/assets/8cfa175d-d6d3-4472-9b9d-7162e6a9df02"
/>
full
<img width="1000" height="600" alt="rms_norm_speed_full"
src="https://github.com/user-attachments/assets/f8248fc6-52a5-4963-a74e-54f7d85173cd"
/>
memory
<img width="1000" height="600" alt="rms_norm_memory_full"
src="https://github.com/user-attachments/assets/7d4b01bb-d425-4e2b-87ef-3d3b1638324b"
/>
[all_benchmark_data_raw.csv](https://github.com/user-attachments/files/26131388/all_benchmark_data_raw.csv)
#### After Optimization
forward
<img width="1000" height="600" alt="rms_norm_speed_forward"
src="https://github.com/user-attachments/assets/89b1a07f-273a-4d07-a6c2-5eb20c7abeb1"
/>
backward
<img width="1000" height="600" alt="rms_norm_speed_backward"
src="https://github.com/user-attachments/assets/4bd741c9-d401-4ade-ac7c-e3b8a241bc32"
/>
full
<img width="1000" height="600" alt="rms_norm_speed_full"
src="https://github.com/user-attachments/assets/c11f7b8c-a3d9-433e-a24e-ebdae343e2e9"
/>
memory
<img width="1000" height="600" alt="rms_norm_memory_full"
src="https://github.com/user-attachments/assets/e2bfedf2-65b2-440d-96fc-4f389cd5ddda"
/>
[all_benchmark_data_optimized.csv](https://github.com/user-attachments/files/26131403/all_benchmark_data_optimized.csv)
- Hardware Type: Atlas 900 A2 PoD
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence1 parent 830ed7c commit 781083b
1 file changed
+2
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
199 | 199 | | |
200 | 200 | | |
201 | 201 | | |
202 | | - | |
| 202 | + | |
203 | 203 | | |
204 | 204 | | |
205 | 205 | | |
| |||
456 | 456 | | |
457 | 457 | | |
458 | 458 | | |
459 | | - | |
| 459 | + | |
460 | 460 | | |
461 | 461 | | |
462 | 462 | | |
| |||
0 commit comments