flash-attn: launch kernels on the caller's CUDA stream #3596

Open
jnises wants to merge 1 commit into huggingface:main from jnises:flash-attn-caller-stream

jnises (Contributor) commented Jun 10, 2026
While attempting to capture a CUDA graph for a model I was working on, I ran into this stream issue.

candle-flash-attn's `run_mha` and candle-flash-attn-v3's `run_mha_v3` hard-code `cudaStream_t stream = 0` ("use the default stream").
When the candle device runs on the context's legacy default stream, this happens to be the same stream and everything is ordered.
But when candle runs on a created stream (`Device::new_cuda_with_stream`, or any multi-stream setup), the attention kernels launch on a different stream than the one that produced Q/K/V and the one that consumes the output, with no ordering between them: a data race.
It also makes these ops impossible to record with CUDA graph stream capture: the legacy default stream cannot be captured, and kernels launched outside the captured stream do not become part of the graph.
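
To make the capture problem concrete, here is a toy model in plain Rust. It is only a sketch of the behaviour described above, not CUDA or candle code: `StreamId`, `Capture`, and `capture_attention` are hypothetical names invented for illustration, and `StreamId(0)` merely stands in for the legacy default stream.

```rust
/// Toy stream handle; `StreamId(0)` stands in for the legacy default stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct StreamId(u32);

/// Toy capture session (hypothetical, not a CUDA API): it keeps only the
/// launches that were issued on the stream being captured.
struct Capture {
    stream: StreamId,
    nodes: Vec<&'static str>,
}

impl Capture {
    fn begin(stream: StreamId) -> Self {
        Capture { stream, nodes: Vec::new() }
    }

    /// Record `kernel` as a graph node only if it was launched on the
    /// captured stream; launches on any other stream are not recorded.
    fn launch(&mut self, kernel: &'static str, on: StreamId) {
        if on == self.stream {
            self.nodes.push(kernel);
        }
    }
}

/// Capture a three-kernel sequence on `device_stream`, with the attention
/// kernel launched on `attn_stream`, and return the recorded graph nodes.
fn capture_attention(device_stream: StreamId, attn_stream: StreamId) -> Vec<&'static str> {
    let mut capture = Capture::begin(device_stream);
    capture.launch("qkv_proj", device_stream);
    capture.launch("flash_attn", attn_stream);
    capture.launch("out_proj", device_stream);
    capture.nodes
}
```

In this model, launching the attention kernel on `StreamId(0)` while capturing another stream leaves `flash_attn` out of the recorded nodes, whereas launching it on the device stream records all three kernels.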

#3565 already routes the buffer device_ptr guards through the device's stream; this PR completes it by launching the kernels on that stream too.

Thread the stream through the extern "C" FFI as a trailing parameter
and pass the device stream from the Rust call sites. This mirrors the
upstream flash-attention API, where the launch stream is a caller
argument (the PyTorch bindings pass the current torch stream).
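
The shape of that change can be sketched in a few lines of Rust. This is a minimal illustration under stated assumptions, not the PR's actual code: `run_mha_sketch`, `DeviceSketch`, and `flash_attn_sketch` are hypothetical names, the argument list is drastically shortened compared with the real `run_mha`, and the "C" function is written in Rust so the example runs without a GPU.

```rust
use std::ffi::c_void;

/// Stand-in for `cudaStream_t`, which is an opaque pointer-sized handle.
type CudaStream = *mut c_void;

/// Hypothetical, heavily shortened stand-in for the C entry point.
/// The stream is the trailing parameter, supplied by the caller.
unsafe extern "C" fn run_mha_sketch(
    q: *const f32,
    out: *mut f32,
    len: usize,
    stream: CudaStream,
) -> usize {
    // A real implementation would launch its kernels on `stream`.
    // Here we only copy the input so the sketch runs on the CPU.
    unsafe { std::ptr::copy_nonoverlapping(q, out, len) };
    // Report which stream the "launch" used, so callers can check it.
    stream as usize
}

/// Hypothetical device wrapper holding the stream its work is queued on.
struct DeviceSketch {
    stream: CudaStream,
}

/// Rust call site: forwards the device's stream instead of relying on
/// a hard-coded stream 0 inside the C code.
fn flash_attn_sketch(dev: &DeviceSketch, q: &[f32]) -> (Vec<f32>, usize) {
    let mut out = vec![0.0f32; q.len()];
    // SAFETY: both pointers are valid for `q.len()` elements and do not overlap.
    let launched_on =
        unsafe { run_mha_sketch(q.as_ptr(), out.as_mut_ptr(), q.len(), dev.stream) };
    (out, launched_on)
}
```

Making the stream an explicit argument keeps the decision with the caller, which already knows which stream produced the inputs and which one will consume the output.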

The extern "C" signature change is internal to each crate: the static
library is built by the crate's own build script and the Rust bindings
in the same crate are its only consumer, so the prototype and its call
sites move together.