Fix sort kernel launch bug when nrows exceed gridDim.y limit #3050

Merged
greenrazer merged 1 commit into huggingface:main from guoqingbao:upstream on Aug 11, 2025

Conversation

@guoqingbao

Contributor

The sort kernel launch has a bug when the number of rows (nrows) to sort exceeds the limit of gridDim.y (which is 65,535 according to the CUDA documentation). Issue reported here: EricLBuehler/candle-vllm#237

This causes a CUDA_ERROR_INVALID_VALUE error at kernel launch. This PR provides a simple fix by changing the row indexing from blockIdx.y (limited by gridDim.y) to blockIdx.x, since gridDim.x allows up to 2^31 - 1 blocks on compute capability 3.0 and later.
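
To make the limit concrete, here is a minimal host-side sketch of the check the driver effectively applies to the grid shape. The constants are the per-axis grid limits from the CUDA documentation for compute capability 3.0 and later; the helper names (`grid_dim_is_valid`, `old_grid_dim`, `new_grid_dim`) are hypothetical and are not part of candle.

```rust
/// Per-axis grid limits from the CUDA programming guide
/// (compute capability 3.0 and later).
const MAX_GRID_DIM_X: u32 = 2_147_483_647; // 2^31 - 1
const MAX_GRID_DIM_Y: u32 = 65_535;
const MAX_GRID_DIM_Z: u32 = 65_535;

/// Hypothetical helper: true if every grid axis is non-zero and within its
/// limit, i.e. the launch would not be rejected because of the grid shape.
fn grid_dim_is_valid(grid_dim: (u32, u32, u32)) -> bool {
    let (x, y, z) = grid_dim;
    (1..=MAX_GRID_DIM_X).contains(&x)
        && (1..=MAX_GRID_DIM_Y).contains(&y)
        && (1..=MAX_GRID_DIM_Z).contains(&z)
}

/// Old layout: one block per row along the y axis.
fn old_grid_dim(nrows: usize) -> (u32, u32, u32) {
    (1, nrows as u32, 1)
}

/// New layout: one block per row along the x axis.
fn new_grid_dim(nrows: usize) -> (u32, u32, u32) {
    (nrows as u32, 1, 1)
}
```

With these helpers, `old_grid_dim(65_536)` fails the check while `new_grid_dim(65_536)` passes, which matches the failure described above.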

Here is the change to the kernel launch. The original used gridDim.y for nrows; the fix moves nrows to gridDim.x:

                 let ncols_pad = next_power_of_2(ncols);
                 let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
                 let cfg = LaunchConfig {
    -                grid_dim: (1, nrows as u32, 1),
    +                grid_dim: (nrows as u32, 1, 1),
                     block_dim: (ncols_pad as u32, 1, 1),
                     shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
                 };
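
For readers who want to try the numbers, the fixed launch configuration above can be sketched as a standalone function. This is only an illustration: `LaunchCfg` is a local stand-in with the same three fields rather than the real type from the CUDA bindings, `sort_launch_cfg` is a hypothetical name, and `next_power_of_2` is assumed to behave like the standard library's `usize::next_power_of_two`.

```rust
/// Local stand-in for the launch configuration in the snippet above
/// (same three fields; not the real type from the CUDA bindings).
#[derive(Debug, PartialEq)]
struct LaunchCfg {
    grid_dim: (u32, u32, u32),
    block_dim: (u32, u32, u32),
    shared_mem_bytes: u32,
}

/// Hypothetical helper mirroring the fixed launch: one block per row on the
/// x axis, one thread per padded column, one u32 of shared memory per thread.
fn sort_launch_cfg(nrows: usize, ncols: usize) -> LaunchCfg {
    // Assumption: the PR's next_power_of_2 matches usize::next_power_of_two.
    let ncols_pad = ncols.next_power_of_two();
    LaunchCfg {
        grid_dim: (nrows as u32, 1, 1),
        block_dim: (ncols_pad as u32, 1, 1),
        shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
    }
}
```

For example, 100,000 rows of 1,000 columns would give a grid of 100,000 blocks on the x axis, 1,024 threads per block, and 4,096 bytes of shared memory per block.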

@guoqingbao

Contributor Author

@greenrazer obvious correctness fix, could you help review the code?

@greenrazer

Contributor

Thanks!

@greenrazer greenrazer merged commit 1829812 into huggingface:main Aug 11, 2025
8 of 9 checks passed