Skip to content

candle-core: add Device::enable_peer_access for cross-CUDA P2P transfers#3525

Open
toddwbucy wants to merge 2 commits into
huggingface:mainfrom
toddwbucy:feat/peer-access-upstream
Open

candle-core: add Device::enable_peer_access for cross-CUDA P2P transfers#3525
toddwbucy wants to merge 2 commits into
huggingface:mainfrom
toddwbucy:feat/peer-access-upstream

Conversation

@toddwbucy

Copy link
Copy Markdown

Fixes #3524.

Adds explicit `Device::enable_peer_access(&other)` so GPU-direct cross-card tensor operations (`Tensor::to_device(&other_cuda_device)`) succeed instead of erroring with `CUDA_ERROR_INVALID_CONTEXT` on first use.

What

  • `CudaDevice::enable_peer_access(&self, other: &Self) -> Result<()>` — calls `cuCtxEnablePeerAccess` in both directions (`self ←→ other`), bound to the appropriate context for each direction.
  • `Device::enable_peer_access(&self, other: &Self) -> Result<()>` — public surface, cuda-only (`#[cfg(feature = "cuda")]`); errors clearly if either side isn't a CUDA device.

Idempotent

  • Same-ordinal pairs no-op, returning `Ok(())` — same-context "peer" access is meaningless.
  • Repeat calls between the same context pair are safe — the driver's `CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED` (704) is folded into `Ok(())` inside the helper. So callers can put `enable_peer_access` on a hot path without worrying about state tracking.

Why explicit, not auto-enable in `BackendDevice::new`

Per #3524 I considered (A) explicit opt-in vs (B) opportunistic auto-enable in `BackendDevice::new`. Going with (A) here:

  • Peer access is a real driver-state mutation with costs (UVA coordination on platforms that have it, no-op rejection on heterogeneous configs). Surfacing the opt-in keeps those visible.
  • (B) requires global state — a registry of all already-constructed `CudaDevice`s — which is a much larger surgery and creates surprising side effects (every `Device::new_cuda(N)` would mutate driver state for every other extant CUDA device).
  • (A) is forward-compatible with (B) — if you later want auto-enable, `BackendDevice::new` can just call this same helper.

Happy to switch to (B) instead if maintainers prefer that direction; the underlying `enable_peer_access_one_way` helper is the same either way.

Errors

Returns the underlying `DriverError` if the device pair doesn't support peer access (some heterogeneous or IOMMU-isolated configurations) or if either context is in a terminal state. Callers wanting to probe support before attempting can use `cuDeviceCanAccessPeer` separately; not wrapped here.

Validation

  • `cargo build -p candle-core --features cuda` — clean
  • Tested in a downstream project's topology bench (2× A6000 + NVLink, also tested over PCIe with NVLink bridge physically removed) where `Tensor::to_device(&other_cuda_device)` previously errored with `INVALID_CONTEXT`. With this patch + a one-time `encoder_device.enable_peer_access(&decoder_device)?` call before the first transfer, the cross-card path succeeds and routes through NVLink (or PCIe P2P).

Adjacent

This is the third small fix from our Jina V4 + multi-GPU work. The other two are #3520 (qwen2 RoPE fp32 cos/sin tables) and #3521 (FA v2.8.3 vendored kernel bump). All three were independently discovered while bringing a Jina V4 embedder up under candle on a 2-A6000 rig; happy to provide more context on any of them.

@EricLBuehler EricLBuehler left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @toddwbucy!

Thanks for the PR - it solves a real bug introduced in #3312 where context was not properly configured. In fact, once we get this PR merged, we should enhance the current to_device path in candle to try peer-enable (new) + peer copy (as current); on failure, fall back to host staging (the old path). However, that should be a later PR.

This PR looks nice, but I have a few review comments.

One thing I did not mention in the specific comments is the added documentation. The code added is paired with plenty of documentation, but I think that in this case it is a bit too much. If you look at the codebase, you will notice that we have more self-explanatory code with the goal of avoiding such long comment blocks. Of course, this PR does introduce methods that bring with them complex semantics, so documenting that is fine, but can you please correct the comments where they are overly detailed (or forced to wrap at ~80 columns although the codebase does not use that convention) or use special characters?

Comment thread candle-core/src/device.rs
/// `Metal`), or if the underlying driver call rejects the request
/// (peer access unsupported on this hardware pair, etc.).
#[cfg(feature = "cuda")]
pub fn enable_peer_access(&self, other: &Self) -> Result<()> {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gate (#[cfg(feature = "cuda")]) breaks the file's convention where all methods are always exposed. They compile unconditionally and bail at runtime (see the methods on Device).

// Restore self as current — bind_to_thread above pushed `other`
// onto this OS thread; callers that proceed with `self` work
// immediately after this method shouldn't have to re-bind.
self.context.bind_to_thread().w()?;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we bail out with enable_peer_access_one_way?, then the final self restoration does not occur. Can you please verify correctness of this implementation & edge case?

It seems that the happy path is correct, and the only behavioral gap is that on a failure in the second direction the thread is left bound to other instead of self. It's not a soundness bug, and it's more of a robustness/consistency cleanup.

/// IOMMU-isolated configurations) or if either context is in a
/// terminal state. Check `cuDeviceCanAccessPeer` separately if you
/// need to probe support before attempting to enable it.
pub fn enable_peer_access(&self, other: &Self) -> Result<()> {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method mutates the calling thread's current CUDA context (ends bound to self, regardless of what was bound before). The doc comment explains the internal rebind but doesn't warn the caller that their thread's current context changes as a side effect. I think it may be worth one line in the # Errors/notes section, since it's a surprising side effect for a method named enable_*.

Comment thread candle-core/src/cuda_backend/device.rs Outdated
pub fn enable_peer_access(&self, other: &Self) -> Result<()> {
let self_ord = self.context.ordinal();
let other_ord = other.context.ordinal();
if self_ord == other_ord {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This compares ordinal(), not the context pointer. Two distinct CudaDevice instances on the same physical GPU but with different contexts would short-circuit without enabling anything, yet a transfer between them still has src_ctx != dst_ctx and would route through memcpy_peer_async.

This is of course a very niche configuration (candle normally has one context per ordinal anyway) and cuCtxEnablePeerAccess between two contexts on the same device would error anyway, so the early return is defensible.

However, the doc says "same physical device … no peer to enable," which glosses over it. Perhaps we should note the assumption?

@toddwbucy toddwbucy force-pushed the feat/peer-access-upstream branch from c03576b to b01538e Compare June 3, 2026 13:17
`Tensor::to_device(&other_cuda_device)` dispatches to
`CudaStorage::transfer_to_device` -> `cudarc::CudaStream::clone_dtod` ->
`memcpy_peer_async` when the two devices have different contexts. That
requires `cuCtxEnablePeerAccess` between the two contexts first; without
it the driver rejects with `CUDA_ERROR_INVALID_CONTEXT`, so any
GPU-direct cross-`CudaDevice` transfer fails on first use.

Add an explicit opt-in:
  - `CudaDevice::enable_peer_access(&self, other)` enables both directions,
    bound to the appropriate context for each, and rebinds `self` on every
    exit path (including errors).
  - `Device::enable_peer_access` delegates when both are CUDA and errors
    otherwise. It compiles unconditionally and bails at runtime without the
    `cuda` feature (dummy-backend stub), matching the other Device methods.

Idempotent: handles sharing one context are a no-op, and repeat calls on
an already-enabled pair return `Ok(())` (the driver's
`CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED` is folded into success).

Not auto-enabled in `BackendDevice::new`: that would need global tracking
of every constructed `CudaDevice` -- a much larger change.

Tested: discovered during a topology bench (encoder GPU0 <-> decoder
GPU1, NVLink-bridged) where `Tensor::to_device` errored with
INVALID_CONTEXT. With this patch + a one-time
`encoder_device.enable_peer_access(&decoder_device)?` call before
the first transfer, the cross-card path succeeds and routes through
the NVLink bridge.

Builds: `cargo build -p candle-core --features cuda`: clean.
@toddwbucy toddwbucy force-pushed the feat/peer-access-upstream branch from b01538e to 861b165 Compare June 3, 2026 13:25
@toddwbucy

Copy link
Copy Markdown
Author

Thanks @EricLBuehler — addressed all of it and rebased onto current main:

  • ordinal → context: now compares context pointers, so same-GPU/different-context handles aren't wrongly short-circuited.
  • error-path restore: the two-way enable is captured and self is unconditionally rebound afterward, so a failure in either direction can't leave the thread on other (verified all four exit paths).
  • side effect: documented under # Side effects (thread ends bound to self).
  • cfg gate: dropped; added a runtime-bail stub in the dummy backend so it compiles unconditionally like the other Device methods.
  • docs: trimmed throughout, dropped the 80-col wrapping and the unicode arrow.
  • Verified with a GPU0/1 P2P round-trip (to_device 0→1→0).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Cross-CUDA Tensor::to_device fails with CUDA_ERROR_INVALID_CONTEXT (no peer-access enable)

3 participants