Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions example/torchstore_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,40 @@
import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint, shutdown_context, this_host

# Run the example : python example/torchstore_rl.py
# Run the example:
# CUDA: python example/torchstore_rl.py
# XPU: python example/torchstore_rl.py (autodetected; oneCCL env from
# ~/env-3.sh and run_deepseek.sh's TCP block must already be set)


def set_cuda_visible_devices(devices: str) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = devices
def _accelerator() -> str:
"""Return ``"cuda"`` or ``"xpu"`` based on what's available, else ``"cpu"``."""
if torch.cuda.is_available():
return "cuda"
if hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
return "cpu"


_ACCEL = _accelerator()


def _set_visible_devices(devices: str) -> None:
"""Pin one process to a single accelerator tile.

Sets the env var that the active backend honors —
``CUDA_VISIBLE_DEVICES`` for CUDA, ``ZE_AFFINITY_MASK`` for XPU.
"""
if _ACCEL == "cuda":
os.environ["CUDA_VISIBLE_DEVICES"] = devices
elif _ACCEL == "xpu":
os.environ["ZE_AFFINITY_MASK"] = devices


class Learner(Actor):
def __init__(self):
# Trainer stays on CPU for the toy model — keeps the example
# focused on weight-sharing semantics, not training perf.
self.device = torch.device("cpu")
self.model = torch.nn.Linear(4, 4, bias=False, device=self.device)
self.optim = torch.optim.AdamW(
Expand All @@ -44,61 +69,60 @@ async def step(
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optim.step()
print("[learner] weights: ", self.model.state_dict())
# Put weights in to torch.store
await ts.put_state_dict(self.model.state_dict(), key="toy_app")


class Generator(Actor):
def __init__(self):
self.model = torch.nn.Linear(4, 4, bias=False, device="cuda")
self.device = torch.device(_ACCEL if _ACCEL != "cpu" else "cpu")
self.model = torch.nn.Linear(4, 4, bias=False, device=self.device)
self.index = current_rank()["gpus"]

@endpoint
async def update_weights(self):
print(f"[generator {self.index}] original weights: {self.model.state_dict()}")
# Fetch weights from torch.store
await ts.get_state_dict(key="toy_app", user_state_dict=self.model.state_dict())
print(f"[generator {self.index}] new weights: {self.model.state_dict()}")

@endpoint
async def generate(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
inputs = inputs.to("cuda")
inputs = inputs.to(self.device)
logits = self.model(inputs)
reward = torch.sum(logits)
return logits, reward


async def main():
"""
The example code shows how to use torchstore to share weights between
trainer/learner and generator apps. The weights are shared synchronously
between the two apps.
"""Trainer/generator weight-sharing demo. The chosen accelerator
(CUDA, XPU, or CPU fallback) is autodetected from what's available
on the host; transport selection inside TorchStore is automatic
(SHM intra-host, xccl on XPU, gloo otherwise).
"""
num_learners = 1
num_generators = 1

# TODO: Show weights re-sharding usecase.
learner_mesh = this_host().spawn_procs(
per_host={"gpus": num_learners},
bootstrap=partial(set_cuda_visible_devices, "0"),
bootstrap=partial(_set_visible_devices, "0"),
)
gen_mesh = this_host().spawn_procs(
per_host={"gpus": num_generators},
bootstrap=partial(set_cuda_visible_devices, "1"),
bootstrap=partial(_set_visible_devices, "1"),
)

await ts.initialize()

learner = learner_mesh.spawn("learner", Learner)
generators = gen_mesh.spawn("generator", Generator)

seed_device = _ACCEL if _ACCEL != "cpu" else "cpu"
logits, reward = await generators.generate.call_one(
torch.randn(4, 4, device="cuda")
torch.randn(4, 4, device=seed_device)
)
for _ in range(3):
await learner.step.call_one(logits, reward)
logits, reward = await generators.generate.call_one(
torch.randn(4, 4, device="cuda")
torch.randn(4, 4, device=seed_device)
)
await generators.update_weights.call_one()

Expand Down
214 changes: 214 additions & 0 deletions example/torchstore_state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Multi-tensor state_dict round-trip via TorchStore.

Mirrors the pattern an RL pipeline uses for trainer/generator weight
sync:

- A ``Trainer``-style actor on its own proc_mesh calls
``ts.put_state_dict``.
- A ``Generator``-style actor on a different proc_mesh calls
``ts.get_state_dict`` to pull the weights back into pre-allocated
buffers.
- Many tensors per state_dict (LoRA adapter shape).

This exercises the per-tensor handshake-cache reuse in transports
that build a process group (gloo, xccl) — a single-tensor smoke
wouldn't catch a regression there.

Device autodetect: cuda > xpu > cpu. Run with::

python example/torchstore_state_dict.py

On Intel XPU (xccl/oneCCL backed) most systems need libfabric on
the TCP provider for oneCCL handshake to succeed::

unset FI_TCP_IFACE
unset CCL_KVS_MODE
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ATL_OFI_PROVIDER=tcp
export CCL_PROCESS_LAUNCHER=hydra
export CCL_ZE_DISABLE_PORT_CHECK=1
export ZE_FLAT_DEVICE_HIERARCHY=FLAT
python example/torchstore_state_dict.py
"""

from __future__ import annotations

import asyncio
import logging
import os
from collections import OrderedDict
from functools import partial

import torch
import torchstore as ts
from monarch.actor import Actor, endpoint, shutdown_context, this_host

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
log = logging.getLogger("state_dict_example")


def _accelerator() -> str:
if torch.cuda.is_available():
return "cuda"
if hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
return "cpu"


_ACCEL = _accelerator()
KEY = "policy_state_dict"


def _set_visible_devices(devices: str) -> None:
"""Pin one process to a single accelerator tile.

Sets the env var the active backend honors —
``CUDA_VISIBLE_DEVICES`` for CUDA, ``ZE_AFFINITY_MASK`` for XPU.
"""
if _ACCEL == "cuda":
os.environ["CUDA_VISIBLE_DEVICES"] = devices
elif _ACCEL == "xpu":
os.environ["ZE_AFFINITY_MASK"] = devices


def _make_state_dict(seed: int) -> OrderedDict[str, torch.Tensor]:
"""Build a small multi-tensor state_dict (LoRA-shaped).

Several distinct tensors so ``put_batch`` actually loops; a
single-tensor put would short-circuit handshake reuse.
"""
g = torch.Generator(device=_ACCEL).manual_seed(seed)
sd: OrderedDict[str, torch.Tensor] = OrderedDict()
for i in range(4):
sd[f"layer.{i}.lora_A.weight"] = torch.randn(8, 16, generator=g, device=_ACCEL)
sd[f"layer.{i}.lora_B.weight"] = torch.randn(16, 8, generator=g, device=_ACCEL)
return sd


class Trainer(Actor):
"""Publishes a state_dict via ts.put_state_dict."""

@endpoint
async def publish(self, seed: int) -> dict:
sd = _make_state_dict(seed)
await ts.put_state_dict(sd, KEY)
return {
k: {
"shape": tuple(v.shape),
"device": str(v.device),
"checksum": float(v.float().sum().item()),
}
for k, v in sd.items()
}


class Generator(Actor):
"""Pulls the state_dict via ts.get_state_dict into pre-allocated buffers."""

@endpoint
async def fetch(self, ref_seed_for_shapes: int) -> dict:
# Pre-allocate destination buffers with the right shapes;
# mirrors how an RL Generator pulls into its own
# model.state_dict().
sd = _make_state_dict(ref_seed_for_shapes)
for v in sd.values():
v.zero_()

sd = await ts.get_state_dict(KEY, user_state_dict=sd, strict=True)
return {
k: {
"shape": tuple(v.shape),
"device": str(v.device),
"checksum": float(v.float().sum().item()),
}
for k, v in sd.items()
}


async def main() -> int:
log.info("accelerator: %s", _ACCEL)
if _ACCEL == "cpu":
log.warning(
"no GPU detected; running on CPU still validates the actor "
"wiring and gloo path."
)

# Actors live on their own meshes → separate OS processes from
# the controller's storage volume. Required for xccl (oneCCL's
# internal KVS is per-process); harmless for gloo/SHM.
trainer_mesh = this_host().spawn_procs(
name="trainer",
per_host={"gpus": 1},
bootstrap=partial(_set_visible_devices, "0"),
)
gen_mesh = this_host().spawn_procs(
name="generator",
per_host={"gpus": 1},
bootstrap=partial(_set_visible_devices, "1"),
)

log.info("ts.initialize ...")
await ts.initialize()
log.info("initialize OK")

trainer = trainer_mesh.spawn("trainer", Trainer)
generator = gen_mesh.spawn("generator", Generator)

seed = 17
log.info("trainer.publish(seed=%d) — multi-tensor put_state_dict", seed)
src_info = next(iter(await trainer.publish.call(seed)))[1]
log.info("trainer wrote %d tensors", len(src_info))
for k, v in src_info.items():
log.info(
" %s: shape=%s device=%s checksum=%.6f",
k,
v["shape"],
v["device"],
v["checksum"],
)

log.info("generator.fetch — multi-tensor get_state_dict")
got_info = next(iter(await generator.fetch.call(seed)))[1]
log.info("generator got %d tensors", len(got_info))

failures: list[str] = []
for k, src in src_info.items():
if k not in got_info:
failures.append(f"missing: {k}")
continue
got = got_info[k]
if got["shape"] != src["shape"]:
failures.append(
f"{k}: shape mismatch src={src['shape']} got={got['shape']}"
)
if abs(got["checksum"] - src["checksum"]) > 1e-3:
failures.append(
f"{k}: checksum mismatch src={src['checksum']:.6f} "
f"got={got['checksum']:.6f}"
)

if failures:
for f in failures:
log.error(" %s", f)
log.error("STATE-DICT FAIL")
await ts.shutdown()
return 1

log.info("STATE-DICT PASS — %d tensors round-tripped", len(src_info))
await ts.shutdown()
return 0


if __name__ == "__main__":
import sys

rc = asyncio.run(main())
shutdown_context().get(timeout=2.0)
sys.exit(rc)
14 changes: 12 additions & 2 deletions torchstore/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from torchstore.transport.torchcomms.uniflow_buffer import TorchCommsTransportBuffer
from torchstore.transport.types import Request, TensorSlice
from torchstore.transport.xccl import xccl_available, XcclTransportBuffer

if TYPE_CHECKING:
from torchstore.strategy import StorageVolumeRef
Expand All @@ -39,14 +40,20 @@ class TransportType(Enum):
TorchComms = auto()
TorchCommsRDMA = TorchComms # Backward compatible alias
Gloo = auto()
XCCL = auto() # Intel oneCCL via torch.distributed; device-resident on XPU
SharedMemory = auto() # POSIX shared memory for same-host transfers


def get_available_transport(storage_volume_ref: "StorageVolumeRef") -> TransportType:
"""Determine the best available transport type for the given storage volume.

Prefers SharedMemory for same-host transfers, then TorchComms (Uniflow RDMA/NVLink),
then MonarchRDMA, then Gloo, otherwise falls back to MonarchRPC.
Order: SharedMemory (same-host) > TorchComms (Uniflow RDMA/NVLink) >
MonarchRDMA (ibverbs) > XCCL (XPU device-resident) > Gloo (TCP/CPU) >
MonarchRPC (last resort).

XCCL beats Gloo on XPU because it keeps tensors on device. Gloo is kept
as the universal cross-platform fallback for non-XPU hosts and as a
safety net if xccl init fails.
"""
# Prefer SharedMemory for same-host transfers
if SHM_ENABLED and is_local_to_volume(storage_volume_ref):
Expand All @@ -57,6 +64,8 @@ def get_available_transport(storage_volume_ref: "StorageVolumeRef") -> Transport
return TransportType.TorchComms
elif monarch_rdma_transport_available():
return TransportType.MonarchRDMA
elif xccl_available():
return TransportType.XCCL
elif gloo_available():
return TransportType.Gloo

Expand All @@ -82,6 +91,7 @@ def create_transport_buffer(storage_volume_ref: "StorageVolumeRef") -> Transport
TransportType.MonarchRPC: MonarchRPCTransportBuffer,
TransportType.MonarchRDMA: MonarchRDMATransportBuffer,
TransportType.Gloo: GlooTransportBuffer,
TransportType.XCCL: XcclTransportBuffer,
TransportType.SharedMemory: SharedMemoryTransportBuffer,
}

Expand Down
Loading