[RMSNorm] Fix JIT recompilation by removing tl.constexpr on rows_per_program & Cleanup Block kernel interface #988
Merged
Tcc0403 merged 2 commits into linkedin:main from Dec 24, 2025
Summary
This PR optimizes the JIT compilation behavior for `_rms_norm_backward_kernel` and cleans up the interface for `_block_rms_norm_backward_kernel`.
- Avoid JIT Recompilation: Removes `tl.constexpr` from the `rows_per_program` argument in `_rms_norm_backward_kernel`.
- Interface Cleanup: Removes the unused `rows_per_program` argument from `_block_rms_norm_backward_kernel`.
Details
- `_rms_norm_backward_kernel`. Currently, `rows_per_program` is marked as `tl.constexpr`, but it is used within a standard dynamic `range` loop (not `tl.static_range`).
  - The `tl.constexpr` hint provides no loop unrolling benefits in this context because the loop bounds are determined at runtime (dependent on `n_rows` and `program_id`). However, Triton still treats the parameter's value as part of the kernel signature.
  - In dynamic-shape scenarios (where `rows_per_program` changes with input size), this unnecessarily triggers JIT recompilation for every new shape, causing severe cache thrashing and CPU overhead without any performance gain.
  - Removing `tl.constexpr` allows the compiled kernel to be reused across different `rows_per_program` values.
- `_block_rms_norm_backward_kernel`. The `rows_per_program` argument was unused in the block-wise implementation. It has been removed to avoid signature pollution and confusion.
Testing Done
Verified that the changes do not introduce performance regressions. The benchmark shows stable latency across different hidden sizes.
Performance Benchmark:
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence
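For reviewers who want to see the recompilation effect described under Details without a GPU, the following is a toy model in plain Python. It is a sketch only, not Triton's actual cache. `ToyJitCache`, `SM_COUNT`, the `BLOCK_SIZE` value, and the `ceil(n_rows / sm_count)` formula for `rows_per_program` are all assumptions made for illustration. The only idea it models is that the values of constexpr arguments become part of the compilation key, while ordinary runtime arguments do not.

```python
import math


class ToyJitCache:
    """Toy stand-in for a JIT cache (hypothetical, not Triton's real implementation)."""

    def __init__(self, constexpr_names):
        # Names of the arguments whose *values* are baked into the compiled binary.
        self.constexpr_names = tuple(constexpr_names)
        self.compiled = {}

    def launch(self, **kwargs):
        # Only constexpr arguments contribute their values to the cache key.
        key = tuple(kwargs[name] for name in self.constexpr_names)
        if key not in self.compiled:
            # A cache miss stands in for a full JIT compilation.
            self.compiled[key] = f"binary#{len(self.compiled)}"
        return self.compiled[key]


def rows_per_program(n_rows, sm_count):
    # Assumed launch heuristic: split the rows evenly across the available SMs.
    return math.ceil(n_rows / sm_count)


SM_COUNT = 108  # hypothetical SM count
row_counts = [512, 1024, 2048, 4096, 8192]

# Before this PR: rows_per_program is constexpr, so its value is part of the key.
before = ToyJitCache(["BLOCK_SIZE", "rows_per_program"])
# After this PR: rows_per_program is a plain runtime argument.
after = ToyJitCache(["BLOCK_SIZE"])

for n_rows in row_counts:
    args = dict(
        BLOCK_SIZE=4096,
        rows_per_program=rows_per_program(n_rows, SM_COUNT),
        n_rows=n_rows,
    )
    before.launch(**args)
    after.launch(**args)

print(len(before.compiled), len(after.compiled))
```

In this model, every distinct `rows_per_program` value forces a new entry in the `before` cache, while the `after` cache holds a single entry for all row counts. Real Triton may still specialize runtime integers on a few properties (such as divisibility by 16), so the actual number of compiled variants after the change can be slightly higher than one.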