Skip to content

[RMSNorm] Fix JIT recompilation by removing tl.constexpr on rows_per_program & Cleanup Block kernel interface#988

Merged
Tcc0403 merged 2 commits intolinkedin:mainfrom
niyunsheng:rms_norm_block_backward
Dec 24, 2025
Merged

[RMSNorm] Fix JIT recompilation by removing tl.constexpr on rows_per_program & Cleanup Block kernel interface#988
Tcc0403 merged 2 commits intolinkedin:mainfrom
niyunsheng:rms_norm_block_backward

Conversation

@niyunsheng
Copy link
Copy Markdown
Contributor

Summary

This PR optimizes the JIT compilation behavior for _rms_norm_backward_kernel and cleans up the interface for _block_rms_norm_backward_kernel.

  1. Avoid JIT Recompilation: Removes tl.constexpr from the rows_per_program argument in _rms_norm_backward_kernel.

  2. Interface Cleanup: Removes the unused rows_per_program argument from _block_rms_norm_backward_kernel.

Details

  1. Fix for Dynamic Shapes in _rms_norm_backward_kernel. Currently, rows_per_program is marked as tl.constexpr, but it is used within a standard dynamic range loop (not tl.static_range).
  • Issue: The tl.constexpr hint provides no loop unrolling benefits in this context because the loop bounds are determined at runtime (dependent on n_rows and program_id). However, Triton still treats the parameter as part of the kernel signature.
  • Impact: In dynamic shape scenarios (where rows_per_program changes with input size), this unnecessarily triggers JIT recompilation for every new shape, causing severe cache thrashing and CPU overhead without any performance gain.
  • Fix: Removing tl.constexpr allows the compiled kernel to be reused across different rows_per_program values.
  1. Cleanup in _block_rms_norm_backward_kernel. The rows_per_program argument was unused in the block-wise implementation. It has been removed to avoid signature pollution and confusion.

Testing Done

Verified that the changes do not introduce performance regressions. The benchmark shows stable latency across different hidden sizes.

Performance Benchmark:

Hidden Size Latency (ms) P50 (ms)
1024.00 0.13 0.11
2048.00 0.12 0.12
4096.00 0.12 0.12
8192.00 0.12 0.11
16384.00 0.18 0.18
32768.00 1.37 1.39
  • Hardware Type: NVIDIA A100-SXM4-80GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Copy link
Copy Markdown
Collaborator

@Tcc0403 Tcc0403 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch!

@Tcc0403 Tcc0403 merged commit 77949e0 into linkedin:main Dec 24, 2025
3 of 7 checks passed
@niyunsheng niyunsheng deleted the rms_norm_block_backward branch December 25, 2025 00:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants