
[PJRT] Basic PP and D2D activation transfer over sockets #5330

Draft
mstojkovicTT wants to merge 3 commits into main from pp-basic-socket-d2d


Ticket

#2831

Problem description

We don't have a pipeline-parallel path running end-to-end through the tt-xla / tt-mlir stack today. We need an initial plan and a PoC.

What's changed

This PR establishes the mechanism (activations move D2D over the fabric, and all stages' devices stay open) and proves it correct. The scheduling work (async submit, microbatching, removing the per-hop device syncs, thread-safe resource management) is what's left to turn sequential placement into genuinely concurrent pipelining.

This whole mechanism is built on explicit per-device addressing, which SPMD takes away (it exposes one virtual device and sends the same program to all devices). We can't run in the SPMD regime because:

  • PipelineParallel places stage i on physical device i and crosses each boundary with h.to(next_device). This requires each physical chip to be an individually addressable PJRT device.
  • The socket transfer is triggered by a single source buffer on a known device being read back to a host pointer, then re-uploaded to a single destination buffer on another known device. The plugin reads src_device_id and dst_device_id off those individual buffers to compute the submesh offsets.

Main things this PR contains:

1. experimental/pipeline.py

This is the user-facing API and the only way to express the pipeline. It places stage i on torch_xla.device(i) and, in forward, threads the activation with h.to(next_device).
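The placement and hand-off described above can be sketched without any framework. This is a minimal illustration under stated assumptions, not the contents of experimental/pipeline.py: `PipelineSketch` and `FakeTensor` are hypothetical names, the device labels are plain strings, and a real implementation would wrap torch modules and use torch_xla.device(i).

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: a value plus the device it lives on."""
    value: float
    device: str

    def to(self, device: str) -> "FakeTensor":
        # Mirrors the h.to(next_device) call: same data, new placement.
        return FakeTensor(self.value, device)


class PipelineSketch:
    """Stage i runs on devices[i]; the activation hops from device to device."""

    def __init__(self, stages: Sequence[Callable[[float], float]],
                 devices: Sequence[str]):
        if len(stages) != len(devices):
            raise ValueError("need exactly one device per stage")
        self.stages = list(stages)
        self.devices = list(devices)
        self.trace: List[str] = []  # devices visited, kept for inspection

    def forward(self, h: FakeTensor) -> FakeTensor:
        for stage, device in zip(self.stages, self.devices):
            if h.device != device:
                h = h.to(device)  # cross the stage boundary
            self.trace.append(h.device)
            # The stage's compute stays on the device it was placed on.
            h = FakeTensor(stage(h.value), h.device)
        return h
```

Each boundary crossing is exactly one `.to(...)` call on a single activation, which is the pattern the plugin has to recognize in order to replace the host round-trip with a socket transfer.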

2. Intercepting the host round-trip in buffer_instance.cc

A cross-device move is expressed as copyToHost(src, host_ptr) followed by copyFromHost(dst, host_ptr) reusing the same host pointer. So:

  • copyToHost stashes the source's still-on-device tensor handle, keyed by the host pointer (its refcount alone keeps the device buffer alive across the readback).
  • copyFromHost (trySocketTransferFromHostPull) checks whether the incoming host pointer matches a stash. If it does and the move is a genuine cross-device hop, it sends the activation D2D over a socket and skips the host upload entirely.

clearPendingDevicePulls() is needed because a stashed device-tensor handle that outlives the mesh device would crash in its destructor at teardown.

3. Keeping all stages' devices open at once: client_instance.cc + module_builder.cc + loaded_executable_instance.cc

A socket needs both endpoints alive simultaneously. The pre-existing code reshaped the single mesh device per executable (close + reopen), which would tear down any other stage's device and make sockets impossible. So:

  • getOrCreateSubmesh carves each stage's device as a live submesh out of one long-lived parent mesh, cached by offset, so multiple stages stay open concurrently and cached handles never go stale.
  • module_builder.cc now only reshapes the parent when the graph needs the current device count. If a stage uses fewer devices, it keeps the parent open and lets execution carve a submesh.
  • loaded_executable_instance.cc detects "this executable targets a strict subset of devices" and routes it to a submesh at the right offset instead of reshaping.
  • getOrCreateSocketPair / closeLiveSubmeshes manage the socket lifecycle (created lazily on first cross-submesh transfer, reused across runs, released before the submeshes close).
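The caching and teardown order in the list above can be illustrated with a small model. This is a sketch under assumptions, not the logic in client_instance.cc: `MeshResources` and its methods are hypothetical, handles are plain strings, and each submesh is assumed to cover a single device identified only by its offset.

```python
from typing import Dict, List, Tuple


class MeshResources:
    """Toy model of one long-lived parent mesh with cached submeshes and sockets."""

    def __init__(self, num_devices: int):
        self.num_devices = num_devices
        self.submeshes: Dict[int, str] = {}            # offset -> submesh handle
        self.sockets: Dict[Tuple[int, int], str] = {}  # (src, dst) -> socket pair
        self.events: List[str] = []                    # teardown order, for inspection

    def get_or_create_submesh(self, offset: int) -> str:
        if not 0 <= offset < self.num_devices:
            raise IndexError(f"offset {offset} is outside the parent mesh")
        if offset not in self.submeshes:
            self.submeshes[offset] = f"submesh@{offset}"  # carved once, then reused
        return self.submeshes[offset]

    def get_or_create_socket_pair(self, src: int, dst: int) -> str:
        if src == dst:
            raise ValueError("a socket pair needs two distinct submeshes")
        # Both endpoints must be live before a socket can connect them.
        self.get_or_create_submesh(src)
        self.get_or_create_submesh(dst)
        key = (src, dst)
        if key not in self.sockets:
            self.sockets[key] = f"socket {src}->{dst}"  # created lazily
        return self.sockets[key]

    def close_live_submeshes(self) -> None:
        # Release sockets first, then the submeshes they connect.
        for key in list(self.sockets):
            self.events.append(f"release {self.sockets.pop(key)}")
        for offset in list(self.submeshes):
            self.events.append(f"close {self.submeshes.pop(offset)}")
```

Because nothing here closes or reopens the parent, opening stage 1's submesh leaves stage 0's handle valid, which is the property the socket path depends on.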

Checklist

  • Wait for tt-mlir uplift of sockets API
