Fix sort kernel launch bug when nrows exceed gridDim.y limit#3050
Merged
Conversation
Contributor
Author
|
@greenrazer obvious correctness fix, could you help review the code? |
Contributor
|
Thanks! |
john-sharratt
pushed a commit
to john-sharratt/candle
that referenced
this pull request
May 7, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The sort kernel launch has a bug when the number of rows (
nrows) to sort exceeds the limit ofgridDim.y(which is 65,535 according to the CUDA documentation). Issue reported here: EricLBuehler/candle-vllm#237This causes a
CUDA_ERROR_INVALID_VALUElaunch error. This PR provides a simple fix by changing the indexing fromblockIdx.y(limited bygridDim.y) toblockIdx.x(sincegridDim.xhas a much larger limit).Here is the original kernel launch, which uses
gridDim.yasnrows: