flash-attn: launch kernels on the caller's CUDA stream#3596
Open
jnises wants to merge 1 commit into
Open
Conversation
candle-flash-attn's run_mha and candle-flash-attn-v3's run_mha_v3
hard-code `cudaStream_t stream = 0` ("use the default stream"). When
the candle device runs on the context's legacy default stream this
happens to be the same stream and everything is ordered. But when
candle runs on a created stream (Device::new_cuda_with_stream, or any
multi-stream setup), the attention kernels launch on a different stream
than the one that produced Q/K/V and the one that consumes the output,
with no ordering between them -- a data race. It also makes these ops
impossible to record with CUDA graph stream capture: the legacy default
stream cannot be captured, and kernels launched outside the captured
stream do not become part of the graph.
huggingface#3565 already routes the buffer device_ptr guards through the device's
stream; this completes it by launching the kernels on that stream too.
Thread the stream through the extern "C" FFI as a trailing parameter
and pass the device stream from the Rust call sites. This mirrors the
upstream flash-attention API, where the launch stream is a caller
argument (the PyTorch bindings pass the current torch stream).
The extern "C" signature change is internal to each crate: the static
library is built by the crate's own build script and the Rust bindings
in the same crate are its only consumer, so the prototype and its call
sites move together.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
While attempting to capturing a CUDA graph for a model I was working on I ran into this stream issue.
candle-flash-attn's run_mha and candle-flash-attn-v3's run_mha_v3 hard-code
cudaStream_t stream = 0.When the candle device runs on the context's legacy default stream this happens to be the same stream and everything is ordered.
But when candle runs on a created stream (Device::new_cuda_with_stream, or any multi-stream setup), the attention kernels launch on a different stream than the one that produced Q/K/V and the one that consumes the output, with no ordering between them, a data race.
It also makes these ops impossible to record with CUDA graph stream capture: the legacy default stream cannot be captured, and kernels launched outside the captured stream do not become part of the graph.
#3565 already routes the buffer device_ptr guards through the device's stream; this PR completes it by launching the kernels on that stream too.