[PJRT] Basic PP and D2D activation transfer over sockets #5330
Draft
mstojkovicTT wants to merge 3 commits into
Ticket
#2831
Problem description
We don't have a pipeline-parallel path running end-to-end through the
tt-xla / tt-mlir stack today. We need an initial plan + PoC.
What's changed
This PR establishes the mechanism (activations move D2D over the fabric, all stages' devices stay open) and proves it correct. The scheduling (async submit, microbatching, removing the per-hop device syncs, thread-safe resource management) is what's left to turn sequential placement into genuinely concurrent pipelining.
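As a rough illustration of the sequential placement described above, here is a minimal pure-Python sketch. It does not use `torch_xla` or any tt-xla code; `ToyDevice`, `ToyTensor`, and `ToyPipeline` are hypothetical names invented for this example, and the "D2D hop" is only a log entry, not a real socket transfer.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyDevice:
    """Hypothetical stand-in for one individually addressable device."""
    index: int
    is_open: bool = True


@dataclass
class ToyTensor:
    """A value tagged with the device it currently lives on."""
    value: float
    device: ToyDevice

    def to(self, dst: ToyDevice, log: List[str]) -> "ToyTensor":
        # A device-to-device hop needs both endpoints open at the same time.
        assert self.device.is_open and dst.is_open
        if dst.index != self.device.index:
            log.append(f"d2d {self.device.index}->{dst.index}")
        return ToyTensor(self.value, dst)


class ToyPipeline:
    """Places stage i on device i and runs the stages strictly in sequence."""

    def __init__(self, stages: List[Callable[[float], float]],
                 devices: List[ToyDevice]) -> None:
        assert len(stages) == len(devices)
        self.stages = stages
        self.devices = devices
        self.log: List[str] = []

    def forward(self, x: float) -> ToyTensor:
        h = ToyTensor(x, self.devices[0])
        for i, stage in enumerate(self.stages):
            # "Compute" stage i on the device the activation already lives on.
            h = ToyTensor(stage(h.value), h.device)
            if i + 1 < len(self.devices):
                # Cross the stage boundary by moving the activation onward.
                h = h.to(self.devices[i + 1], self.log)
        return h


devices = [ToyDevice(i) for i in range(3)]
pipe = ToyPipeline([lambda v: v + 1, lambda v: v * 2, lambda v: v - 3], devices)
out = pipe.forward(4.0)
print(out.value, out.device.index, pipe.log)
```

Each stage runs only after the previous one has finished, so this models sequential placement only; overlapping stages across microbatches is the scheduling work the PR leaves for later.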
This whole mechanism is built on explicit per-device addressing, which SPMD takes away (it has one virtual device and sends the same program to all devices). We can't be in the SPMD regime because:
- `PipelineParallel` places stage `i` on physical device `i` and crosses each boundary with `h.to(next_device)`. This requires each physical chip to be an individually addressable PJRT device.
- `src_device_id` and `dst_device_id` are read off those individual buffers to compute the submesh offsets.

Main things this PR contains:
1. `experimental/pipeline.py`

This is the user-facing API and the only way to express the pipeline. It places stage `i` on `torch_xla.device(i)` and, in `forward`, threads the activation with `h.to(next_device)`.

2. Intercepting the host round-trip in `buffer_instance.cc`

A cross-device move is expressed as `copyToHost(src, host_ptr)` followed by `copyFromHost(dst, host_ptr)`, reusing the same host pointer. So:

- `copyToHost` stashes the source's still-on-device tensor handle, keyed by the host pointer (its refcount alone keeps the device buffer alive across the readback).
- `copyFromHost` (`trySocketTransferFromHostPull`) checks whether the incoming host pointer matches a stash; if it does and it's a genuine cross-device hop, it sends the activation D2D over a socket and skips the host upload entirely.
- `clearPendingDevicePulls()` is needed because a stashed device-tensor handle that outlives the mesh device would crash in its destructor at teardown.

3. Keeping all stages' devices open at once: `client_instance.cc` + `module_builder.cc` + `loaded_executable_instance.cc`

A socket needs both endpoints alive simultaneously. The pre-existing code reshaped the single mesh device per executable (close + reopen), which would tear down any other stage's device and make sockets impossible. So:

- `getOrCreateSubmesh` carves each stage's device as a live submesh out of one long-lived parent mesh, cached by offset, so multiple stages stay open concurrently and cached handles never go stale.
- `module_builder.cc` now only reshapes the parent when the graph needs ≥ the current device count. If a stage uses fewer devices, it keeps the parent open and lets execution carve a submesh.
- `loaded_executable_instance.cc` detects "this executable targets a strict subset of devices" and routes it to a submesh at the right offset instead of reshaping.
- `getOrCreateSocketPair` / `closeLiveSubmeshes` manage the socket lifecycle (created lazily on first cross-submesh transfer, reused across runs, released before the submeshes close).

Checklist