Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions torchft/_test/diloco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
layers.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
)
# pyrefly: ignore [bad-return]
return outer_optimizers

def setup_pg(self) -> FakeProcessGroupWrapper:
Expand Down Expand Up @@ -283,4 +284,5 @@ def train_loop(self) -> dict[str, Any]:
if self.manager.current_step() >= 4:
break

# pyrefly: ignore [bad-return]
return all_state_dicts
1 change: 1 addition & 0 deletions torchft/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def gen_views(inp: torch.Tensor) -> list[tuple[int, ...]]:
if size % m == 0:
views.append((m, size // m))

# pyrefly: ignore [bad-return]
return views


Expand Down
2 changes: 2 additions & 0 deletions torchft/checkpointing/http_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,9 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
else:
out.append(v)
else:
# pyrefly: ignore [bad-argument-type]
out.append(v)
# pyrefly: ignore [bad-return]
return out


Expand Down
1 change: 1 addition & 0 deletions torchft/checkpointing/pg_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
tensor = recv(path, v.local)
values.append(DTensor(tensor, v.spec, requires_grad=False))
else:
# pyrefly: ignore [bad-argument-type]
values.append(v)

for work in works:
Expand Down
15 changes: 13 additions & 2 deletions torchft/checkpointing/transport_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def run(rank: int) -> CheckpointTransport[dict[str, object]]:
)
else:
got = transport.recv_checkpoint(
src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=10)
src_rank=0,
# pyrefly: ignore [unbound-name]
metadata=metadata,
step=1,
timeout=timedelta(seconds=10),
)
assertStateDictEqual(self, got, state_dict)

Expand All @@ -106,7 +111,12 @@ def run(rank: int) -> CheckpointTransport[dict[str, object]]:
)
elif rank == 2:
got = transport.recv_checkpoint(
src_rank=0, metadata=metadata, step=2, timeout=timedelta(seconds=10)
src_rank=0,
# pyrefly: ignore [unbound-name]
metadata=metadata,
step=2,
timeout=timedelta(seconds=10),
)
assertStateDictEqual(self, got, state_dict)

Expand All @@ -118,6 +128,7 @@ def run(rank: int) -> CheckpointTransport[dict[str, object]]:
with self.assertRaisesRegex(Exception, TIMEOUT_REGEX):
transport.recv_checkpoint(
src_rank=0,
# pyrefly: ignore [unbound-name]
metadata=metadata,
step=3,
timeout=timedelta(milliseconds=10),
Expand Down
1 change: 1 addition & 0 deletions torchft/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, length: int) -> None:
def __len__(self) -> int:
    """Return the number of items, as stored in ``self.length``."""
    return self.length

# pyrefly: ignore [bad-override-param-name]
def __getitem__(self, idx: int) -> int:
    """Return ``idx`` unchanged, so every sample equals its own index."""
    return idx

Expand Down
6 changes: 6 additions & 0 deletions torchft/diloco_regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
outer_optimizers.append(
MockOptimizer(self.model.layers[i].parameters(), lr=self.outer_lr)
)
# pyrefly: ignore [bad-return]
return outer_optimizers

def train_loop(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -366,11 +367,13 @@ def test_diloco_mocked_updates(
lighthouse.shutdown()

# Check results against fixture or validate parameter updates
# pyrefly: ignore [bad-argument-type]
compared_with_fixture = self._check_against_fixture(results)

if not compared_with_fixture:
# If no fixture comparison was done, validate parameters directly
self._validate_parameter_updates(
# pyrefly: ignore [bad-argument-type]
results[0][0],
n_fragments,
sync_every,
Expand Down Expand Up @@ -454,6 +457,7 @@ def test_diloco_mocked_failure_recovery(
lighthouse.shutdown()

# Check results against fixture or validate failure recovery
# pyrefly: ignore [bad-argument-type]
compared_with_fixture = self._check_against_fixture(results)

if not compared_with_fixture:
Expand All @@ -466,7 +470,9 @@ def test_diloco_mocked_failure_recovery(

# Verify that both replicas have the same global parameters at the end
# Extract the global parameter history from both replicas
# pyrefly: ignore [bad-index]
rep0_global = results[0][0]["global_parameter_history"]
# pyrefly: ignore [bad-index]
rep1_global = results[1][0]["global_parameter_history"]

# Get the last step in both histories
Expand Down
1 change: 1 addition & 0 deletions torchft/examples/slurm/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _make_app(replica_id: int, cli_args: argparse.Namespace) -> specs.AppDef:

# gloo
if os.environ.get("GLOO_SOCKET_IFNAME") is not None:
# pyrefly: ignore [unsupported-operation]
env["GLOO_SOCKET_IFNAME"] = os.environ.get("GLOO_SOCKET_IFNAME")

# application log levels
Expand Down
2 changes: 2 additions & 0 deletions torchft/fsdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def apply_set_all_reduce_hook(m: nn.Module) -> None:
ColwiseParallel(),
)
shard_model = fully_shard(model, mesh=fsdp_mesh)
# pyrefly: ignore [missing-attribute]
shard_model.apply(apply_set_all_reduce_hook)
# pyrefly: ignore [not-callable]
shard_model(batch).mean().backward()

# pyre-ignore[56]: Pyre was not able to infer the type of argument
Expand Down
1 change: 1 addition & 0 deletions torchft/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@


class _IPv6HTTPServer(ThreadingHTTPServer):
    """``ThreadingHTTPServer`` variant that uses the IPv6 address family.

    Overrides two class attributes of the base server: ``address_family``
    is set to ``socket.AF_INET6`` and ``request_queue_size`` to 1024.
    """

    # pyrefly: ignore [bad-override-mutable-attribute]
    address_family: socket.AddressFamily = socket.AF_INET6
    request_queue_size: int = 1024
12 changes: 11 additions & 1 deletion torchft/local_sgd_integ_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,11 @@ def test_local_sgd_recovery(self, use_cuda: bool) -> None:
# LocalSGD only guarantees that the model is consistent across
# replicas but uses separate optimizer states.
torch.testing.assert_close(
state_dict[0]["model"], state_dicts[0][0]["model"], check_device=False
# pyrefly: ignore [bad-index]
state_dict[0]["model"],
# pyrefly: ignore [bad-index]
state_dicts[0][0]["model"],
check_device=False,
)

self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
Expand Down Expand Up @@ -283,6 +287,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
lighthouse.shutdown()

rep0, rep1 = state_dicts
# pyrefly: ignore [bad-argument-type]
assert_equal_global_state(1, rep1, rep0)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
Expand Down Expand Up @@ -361,6 +366,7 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
# of step Manager Step 2
#
# Outer optimizer and global model should be the same
# pyrefly: ignore [bad-argument-type]
assert_equal_global_state(1, rep1, rep0)

self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
Expand Down Expand Up @@ -431,6 +437,7 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:

rep0, rep1 = state_dicts

# pyrefly: ignore [bad-argument-type]
assert_equal_global_state(2, rep1, rep0)

self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
Expand Down Expand Up @@ -513,7 +520,9 @@ def test_streaming_diloco_upscale(

rep0, rep1, rep2 = state_dicts

# pyrefly: ignore [bad-argument-type]
assert_equal_global_state(n_fragments, rep0, rep1)
# pyrefly: ignore [bad-argument-type]
assert_equal_global_state(n_fragments, rep0, rep2)

for event_injector in event_injectors:
Expand Down Expand Up @@ -585,6 +594,7 @@ def test_streaming_diloco_commit_failure(

rep0, rep1 = state_dicts

# pyrefly: ignore [bad-argument-type]
assert_equal_global_state(n_fragments, rep0, rep1)

for event_injector in event_injectors:
Expand Down
5 changes: 5 additions & 0 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def __init__(
num_chunks=0,
)

# pyrefly: ignore [invalid-type-var]
self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = (
checkpoint_transport
)
Expand Down Expand Up @@ -1092,6 +1093,7 @@ def __init__(self, value: object) -> None:
super().__init__()
self._value = value

# pyrefly: ignore [bad-override]
def value(self) -> object:
    """Return the object that ``__init__`` stored in ``self._value``."""
    return self._value

Expand All @@ -1102,6 +1104,7 @@ def then(
"This future is only supposed to be used in callback chain to extract the value"
)

# pyrefly: ignore [bad-override]
def wait(self) -> object:
raise NotImplementedError(
"This future is only supposed to be used in callback chain to extract the value"
Expand Down Expand Up @@ -1180,10 +1183,12 @@ def then(
managed_work._managed_fut_tail = self._next
return cast(torch.futures.Future[S], self._next)

# pyrefly: ignore [bad-override]
def wait(self) -> object:
    """Wait on the wrapped future and return whatever its ``wait()`` returns.

    Raises:
        AssertionError: if ``self._fut`` is falsy (no future attached).
    """
    assert self._fut
    return self._fut.wait()

# pyrefly: ignore [bad-override]
def value(self) -> object:
raise NotImplementedError(
"This future is supposed to be used to create callback chain"
Expand Down
4 changes: 4 additions & 0 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,17 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:

bad_fut = torch.futures.Future()
bad_fut.set_exception(RuntimeError("injected failure"))
# pyrefly: ignore [missing-attribute]
manager._pg.allreduce.return_value.get_future.return_value = bad_fut
manager.allreduce(torch.tensor([1.0])).wait()
# pyrefly: ignore [missing-attribute]
self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
self.assertTrue(manager._errored)
self.assertFalse(manager.should_commit())
self.assertTrue(manager._errored)

# cleanup
# pyrefly: ignore [missing-attribute]
manager._pg.allreduce.reset_mock(return_value=True)

# recover on next step
Expand Down Expand Up @@ -549,6 +552,7 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
self.assertIsNone(manager.errored())

e = RuntimeError("injected failure")
# pyrefly: ignore [bad-argument-type]
fut.set_exception(e)
error = manager.errored()
assert error is not None
Expand Down
3 changes: 3 additions & 0 deletions torchft/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,18 @@ def zero_grad(self, set_to_none: bool = True) -> None:
self.manager.start_quorum()
self.optim.zero_grad(set_to_none)

# pyrefly: ignore [bad-override]
def step(self, closure: Optional[object] = None) -> None:
    """Step the wrapped optimizer only when the manager approves the commit.

    Args:
        closure: must be ``None``; closure-based optimizers are rejected.

    Raises:
        AssertionError: if a closure is supplied.
    """
    assert closure is None, "optimizers that use closures are not supported"
    # The update is skipped entirely when should_commit() returns a falsy value.
    if self.manager.should_commit():
        self.optim.step()

@property
# pyrefly: ignore [bad-override]
def param_groups(self) -> List[Dict[str, Any]]:
    """Return the ``param_groups`` of the wrapped optimizer (``self.optim``)."""
    return self.optim.param_groups

@property
# pyrefly: ignore [bad-override]
def state(self) -> Mapping[torch.Tensor, object]:
    """Return the ``state`` mapping of the wrapped optimizer (``self.optim``)."""
    return self.optim.state
2 changes: 2 additions & 0 deletions torchft/otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
def export(self, batch: Sequence[ReadableLogRecord]) -> LogRecordExportResult:
    """Forward ``batch`` to every exporter in ``self._exporters``.

    The return values of the individual exporters are not inspected;
    ``LogRecordExportResult.SUCCESS`` is always returned.
    """
    for e in self._exporters:
        e.export(batch)
    # pyrefly: ignore [missing-attribute]
    return LogRecordExportResult.SUCCESS

def shutdown(self) -> None:
Expand Down Expand Up @@ -83,6 +84,7 @@ def setup_logger(name: str) -> None:

exporter = TeeLogExporter(
exporters=[
# pyrefly: ignore [bad-instantiation]
ConsoleLogRecordExporter(),
OTLPLogExporter(
timeout=5,
Expand Down
13 changes: 13 additions & 0 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def allreduce(
"""
raise NotImplementedError("not implemented")

# pyrefly: ignore [bad-override]
def allreduce_coalesced(
self,
tensors: List[torch.Tensor],
Expand Down Expand Up @@ -355,6 +356,7 @@ def register(self, name: str) -> "ProcessGroup":
)

@property
# pyrefly: ignore [bad-override]
def group_name(self) -> str:
if self._group_name is None:
raise ValueError("ProcessGroup name not set")
Expand Down Expand Up @@ -655,6 +657,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
if self._global_ranks:
backend_class.options.global_ranks_in_group = self._global_ranks
if self._group_rank and self._group_world_size:
# pyrefly: ignore [bad-assignment]
backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

pg._register_backend(
Expand Down Expand Up @@ -823,6 +826,7 @@ def _opts_hook(self, opts: T) -> T:
# crash the whole program.
if hasattr(opts, "timeout"):
# apply default timeout to disable
# pyrefly: ignore [missing-attribute]
opts.timeout = AllgatherOptions().timeout
return opts

Expand Down Expand Up @@ -861,6 +865,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
if self._global_ranks:
opts.global_ranks_in_group = self._global_ranks
if self._group_rank and self._group_world_size:
# pyrefly: ignore [bad-assignment]
opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

pg = BaseProcessGroup(store, rank, world_size)
Expand Down Expand Up @@ -939,6 +944,7 @@ def _opts_hook(self, opts: T) -> T:
# crash the whole program.
if hasattr(opts, "timeout"):
# apply default timeout to disable
# pyrefly: ignore [missing-attribute]
opts.timeout = AllgatherOptions().timeout
return opts

Expand Down Expand Up @@ -1529,6 +1535,7 @@ def configure(
else -1
)

# pyrefly: ignore [bad-assignment]
self._p = p = ctx.Process(
target=self._worker,
args=(
Expand Down Expand Up @@ -1934,8 +1941,10 @@ class _PickleSafeOptions:
@classmethod
def safe_args(cls, args: T) -> T:
if isinstance(args, tuple):
# pyrefly: ignore [bad-return]
return tuple(cls.safe_args(arg) for arg in args)
elif isinstance(args, list):
# pyrefly: ignore [bad-return]
return [cls.safe_args(arg) for arg in args]
elif isinstance(
args,
Expand All @@ -1949,17 +1958,21 @@ def safe_args(cls, args: T) -> T:
ReduceScatterOptions,
),
):
# pyrefly: ignore [bad-return]
return cls.from_torch(args)
else:
return args

@classmethod
def unsafe_args(cls, args: T) -> T:
    """Recursively convert instances of ``cls`` inside ``args`` via ``to_torch()``.

    Tuples and lists are rebuilt element by element (keeping their container
    type), instances of ``cls`` are replaced by ``args.to_torch()``, and any
    other value is returned unchanged.
    """
    if isinstance(args, tuple):
        # pyrefly: ignore [bad-return]
        return tuple(cls.unsafe_args(arg) for arg in args)
    elif isinstance(args, list):
        # pyrefly: ignore [bad-return]
        return [cls.unsafe_args(arg) for arg in args]
    elif isinstance(args, cls):
        # pyrefly: ignore [bad-return]
        return args.to_torch()
    else:
        return args
Expand Down
1 change: 1 addition & 0 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ def test_functional_collectives(self) -> None:

try:
t = torch.zeros(10)
# pyrefly: ignore [missing-attribute]
_functional_collectives.all_reduce(t, "sum", pg).wait()
finally:
pg.unregister()
Expand Down
Loading