pyre strict #29

Merged
merged 1 commit on Dec 9, 2024

1 change: 1 addition & 0 deletions .pyre_configuration
@@ -1,4 +1,5 @@
{
"strict": true,
"site_package_search_strategy": "pep561",
"source_directories": [
{
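
With "strict": true set project-wide, every module is checked under Pyre's strict rules: function parameters and returns must be annotated, class attributes and globals need explicit types, and remaining errors have to be suppressed with coded pyre-fixme comments rather than ignored silently. A minimal sketch of a strict-clean module (names here are illustrative, not from this repository):

from typing import Optional


class Counter:
    # strict mode requires explicit annotations on class-level assignments
    start: int = 0

    def increment(self, by: Optional[int] = None) -> int:
        # every parameter and the return type must be annotated
        self.start += 1 if by is None else by
        return self.start
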
8 changes: 5 additions & 3 deletions torchft/data.py
@@ -20,6 +20,7 @@
from torch.utils import data


# pyre-fixme[24]: expected generic parameter
class DistributedSampler(data.distributed.DistributedSampler):
"""
DistributedSampler extends the standard PyTorch DistributedSampler with a
@@ -49,7 +50,7 @@ def __init__(
num_replica_groups: int,
rank: Optional[int] = None,
num_replicas: Optional[int] = None,
**kwargs,
**kwargs: object,
) -> None:
"""
Args:
@@ -64,12 +65,13 @@ def __init__(
if num_replicas is None:
num_replicas = dist.get_world_size()

self.global_rank = rank + num_replicas * replica_group
self.global_world_size = num_replicas * num_replica_groups
self.global_rank: int = rank + num_replicas * replica_group
self.global_world_size: int = num_replicas * num_replica_groups

super().__init__(
dataset,
rank=self.global_rank,
num_replicas=self.global_world_size,
# pyre-fixme[6]: got object
**kwargs,
)
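
The newly annotated rank arithmetic above flattens a (replica_group, rank) pair into a single global rank, with each replica group owning a contiguous block of num_replicas ranks. A standalone sketch of that mapping (function and variable names are illustrative):

def flatten_rank(rank: int, replica_group: int, num_replicas: int) -> int:
    # group g owns the contiguous block [g * num_replicas, (g + 1) * num_replicas)
    return rank + num_replicas * replica_group


num_replicas, num_replica_groups = 4, 3
global_ranks = {
    flatten_rank(r, g, num_replicas)
    for g in range(num_replica_groups)
    for r in range(num_replicas)
}
# every (group, rank) pair maps to a distinct rank in the flattened world
assert global_ranks == set(range(num_replicas * num_replica_groups))
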
13 changes: 9 additions & 4 deletions torchft/ddp.py
@@ -44,12 +44,17 @@ class DistributedDataParallel(parallel.DistributedDataParallel):
same across workers.
"""

def __init__(self, manager: "Manager", module: nn.Module, **args) -> None:
def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
# use a dummy PG to soak up the init all reduce, actual comms will go
# through the comm_hook.
pg = ProcessGroupDummy(0, 1)

super().__init__(module, process_group=pg, **args)
super().__init__(
module,
process_group=pg,
# pyre-fixme[6]: got object
**kwargs,
)

self.register_comm_hook(manager, self._comm_hook)

@@ -70,12 +75,12 @@ class PureDistributedDataParallel(nn.Module):
may be very slow for real models.
"""

def __init__(self, manager: "Manager", module: nn.Module):
def __init__(self, manager: "Manager", module: nn.Module) -> None:
super().__init__()

self.module = module

def post_grad_hook(p):
def post_grad_hook(p: torch.Tensor) -> None:
if p.grad is not None:
manager.allreduce_grad(p.grad)

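
Typing **kwargs as object is the recurring pattern in this PR for keyword passthroughs: it avoids an implicit Any under strict mode, at the cost of a targeted suppression when the values are forwarded to a callee with concrete parameter types. A small sketch of the same shape outside torchft (names are illustrative):

def base_configure(verbose: bool = False, retries: int = 3) -> None:
    print(verbose, retries)


def configure(**kwargs: object) -> None:
    # the values are only known to be `object` here, so the forwarding call
    # needs a coded suppression rather than silencing the whole file
    # pyre-fixme[6]: got object
    base_configure(**kwargs)


configure(verbose=True)
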
4 changes: 2 additions & 2 deletions torchft/ddp_test.py
@@ -18,7 +18,7 @@


class TestDDP(TestCase):
def test_pure_ddp(self):
def test_pure_ddp(self) -> None:
manager = create_autospec(Manager)

m = nn.Linear(3, 4)
@@ -34,7 +34,7 @@ def test_pure_ddp(self):

self.assertEqual(manager.allreduce_grad.call_count, len(list(m.parameters())))

def test_ddp(self):
def test_ddp(self) -> None:
manager = create_autospec(Manager)

call_count = 0
4 changes: 2 additions & 2 deletions torchft/http.py
@@ -3,5 +3,5 @@


class _IPv6HTTPServer(ThreadingHTTPServer):
address_family = socket.AF_INET6
request_queue_size = 1024
address_family: socket.AddressFamily = socket.AF_INET6
request_queue_size: int = 1024
23 changes: 13 additions & 10 deletions torchft/manager.py
@@ -25,6 +25,7 @@

"""

import concurrent.futures
import logging
import os
import socket
@@ -35,8 +36,7 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast

import torch
from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work
from torch.optim import Optimizer
from torch.distributed import ReduceOp, TCPStore

from torchft.checkpointing import CheckpointServer
from torchft.torchft import Manager as _Manager, ManagerClient
@@ -81,8 +81,8 @@ class Manager:
def __init__(
self,
pg: "ProcessGroup",
load_state_dict: Callable[[object], None],
state_dict: Callable[[], object],
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
min_replica_size: int,
port: int = MANAGER_DEFAULT_PORT,
use_async_quorum: bool = True,
@@ -124,14 +124,15 @@ def __init__(
world_size = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size

self._ckpt_server = CheckpointServer(
lambda: {
def _manager_state_dict() -> Dict[str, T]:
return {
"user": state_dict(),
"torchft": self.state_dict(),
"torchft": cast(T, self.state_dict()),
}
)

self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict)
self._executor = ThreadPoolExecutor(max_workers=1)
self._quorum_future = None
self._quorum_future: Optional[concurrent.futures.Future] = None

self._store = TCPStore(
host_name=store_addr,
@@ -140,7 +141,7 @@ def __init__(
wait_for_workers=False,
)
self._pg = pg
self._manager = None
self._manager: Optional[_Manager] = None

if rank == 0:
hostname = socket.gethostname()
@@ -208,6 +209,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
fut.set_result(grad)
return fut

assert self._quorum_future is not None, "must call step before allreduce_grad"
self._quorum_future.result()

if not self.is_participating():
@@ -397,6 +399,7 @@ def _apply_pending_state_dict(self) -> None:
assert self._healing, "must be in healing state"

# synchronize on future
assert self._quorum_future is not None, "must call step before should_commit"
self._quorum_future.result()

assert self._pending_state_dict is not None, "checkpoint was not staged"
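
The Manager changes follow another strict-mode pattern: attributes that start out unset are declared Optional and then narrowed with a messaged assert before use. A minimal sketch of that Optional-future pattern, separate from the Manager class itself (names are illustrative):

from __future__ import annotations

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


class Stage:
    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        # declared Optional because nothing has been submitted yet
        self._future: Optional[concurrent.futures.Future[int]] = None

    def start(self) -> None:
        self._future = self._executor.submit(lambda: 42)

    def wait(self) -> int:
        # the assert documents the calling contract and narrows the Optional
        assert self._future is not None, "must call start before wait"
        return self._future.result()


stage = Stage()
stage.start()
assert stage.wait() == 42
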
32 changes: 19 additions & 13 deletions torchft/manager_integ_test.py
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from typing import Set, Tuple
from typing import Dict, Set, Tuple
from unittest import TestCase

import torch
@@ -15,14 +15,14 @@


class MyModel(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.model = nn.Sequential(
nn.Linear(3, 4),
nn.Sigmoid(),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)


@@ -52,7 +52,7 @@ def worker_manager(
lighthouse_address: str,
failure_injector: FailureInjector,
attempts: int = 3,
) -> None:
) -> Dict[str, Dict[str, object]]:
for i in range(attempts):
try:
print(f"starting worker {replica_id} attempt {i}")
@@ -65,10 +65,12 @@
raise
continue

raise RuntimeError("ran out of attempts")


def train_loop(
replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
) -> None:
) -> Dict[str, Dict[str, object]]:
with ExitStack() as stack:
store = dist.TCPStore(
host_name="localhost",
@@ -77,11 +79,11 @@ def train_loop(
wait_for_workers=False,
)

def load_state_dict(state_dict):
def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict():
def state_dict() -> Dict[str, Dict[str, object]]:
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
@@ -103,8 +105,10 @@ def state_dict():
)
stack.callback(manager.shutdown)

m = DistributedDataParallel(manager, MyModel())
optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
m: nn.Module = DistributedDataParallel(manager, MyModel())
optimizer: optim.Optimizer = OptimizerWrapper(
manager, optim.Adam(m.parameters())
)
criterion = nn.CrossEntropyLoss()

while True:
@@ -120,14 +124,16 @@ def state_dict():
optimizer.step()

if manager.current_step() >= 5:
# return state_dict so we can check consistency
return state_dict()
break

failure_injector.check(manager.current_step())

# return state_dict so we can check consistency
return state_dict()


class ManagerIntegTest(TestCase):
def test_ddp_healthy(self):
def test_ddp_healthy(self) -> None:
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
@@ -157,7 +163,7 @@ def test_ddp_healthy(self):
for state_dict in state_dicts:
torch.testing.assert_close(state_dict, state_dicts[0])

def test_ddp_recovery(self):
def test_ddp_recovery(self) -> None:
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
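
train_loop and its retry wrapper now declare a concrete Dict return type, which forces every code path to either return that type or raise; hence the return moved out of the training loop and the wrapper raises once its attempts are exhausted. A condensed sketch of that retry shape (names are illustrative):

import random
from typing import Dict


def flaky_step() -> Dict[str, int]:
    if random.random() < 0.5:
        raise RuntimeError("transient failure")
    return {"status": 0}


def run_with_retries(attempts: int = 3) -> Dict[str, int]:
    for i in range(attempts):
        try:
            return flaky_step()
        except RuntimeError:
            # retry on transient failures
            print(f"attempt {i} failed, retrying")
            continue
    # raising here keeps the declared return type non-Optional: the function
    # never falls off the end and implicitly returns None
    raise RuntimeError("ran out of attempts")
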
26 changes: 15 additions & 11 deletions torchft/manager_test.py
@@ -51,12 +51,12 @@ def _create_manager(
return manager

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager(self, client_mock) -> None:
def test_manager(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
self.assertEqual(client_mock.call_count, 1)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_state_dict(self, client_mock) -> None:
def test_state_dict(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

state_dict = manager.state_dict()
@@ -78,7 +78,7 @@ def test_state_dict(self, client_mock) -> None:
self.assertEqual(manager.batches_committed(), 2345)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_happy(self, client_mock) -> None:
def test_quorum_happy(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
client_mock().should_commit = lambda rank, step, should_commit: should_commit

@@ -113,7 +113,7 @@ def test_quorum_happy(self, client_mock) -> None:
self.assertEqual(manager.batches_committed(), 2)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_heal_sync(self, client_mock) -> None:
def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
manager = self._create_manager(use_async_quorum=False)
client_mock().should_commit = lambda rank, step, should_commit: should_commit

@@ -153,7 +153,9 @@ def test_quorum_heal_sync(self, client_mock) -> None:
self.assertEqual(self.load_state_dict.call_count, 1)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
def test_quorum_heal_async_not_enough_participants(
self, client_mock: MagicMock
) -> None:
manager = self._create_manager(use_async_quorum=True, min_replica_size=2)
client_mock().should_commit = lambda rank, step, should_commit: should_commit

@@ -177,6 +179,7 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
self.assertEqual(manager._step, 0)

manager.step()
assert manager._quorum_future is not None
manager._quorum_future.result()
self.assertTrue(manager._healing)
self.assertFalse(manager.is_participating())
@@ -204,7 +207,7 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
self.assertEqual(manager.batches_committed(), 0)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_heal_async_zero_grad(self, client_mock) -> None:
def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
manager = self._create_manager(use_async_quorum=True, min_replica_size=1)
client_mock().should_commit = lambda rank, step, should_commit: should_commit

@@ -228,6 +231,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock) -> None:
self.assertEqual(manager._step, 0)

manager.step()
assert manager._quorum_future is not None
manager._quorum_future.result()
self.assertTrue(manager._healing)

@@ -253,7 +257,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock) -> None:
self.assertEqual(manager.batches_committed(), 1)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_allreduce_error(self, client_mock) -> None:
def test_allreduce_error(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
client_mock().should_commit = lambda rank, step, should_commit: should_commit

@@ -338,7 +342,7 @@ def test_allreduce_error(self, client_mock) -> None:
self.assertTrue(manager.should_commit())

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_fixed_world_size(self, client_mock) -> None:
def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
# test active and spares
for rank in [1, 2]:
manager = self._create_manager(
@@ -375,15 +379,15 @@ def test_quorum_fixed_world_size(self, client_mock) -> None:
self.assertEqual(manager.batches_committed(), 2)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_report_error(self, client_mock) -> None:
def test_manager_report_error(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

self.assertFalse(manager.errored())
manager.report_error()
self.assertTrue(manager.errored())

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future(self, client_mock) -> None:
def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

self.assertFalse(manager.errored())
@@ -398,7 +402,7 @@ def test_manager_wrap_future(self, client_mock) -> None:
self.assertEqual(manager._pending_work, [wrapped_fut])

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_numerics(self, client_mock) -> None:
def test_manager_numerics(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

manager._quorum_future = MagicMock()
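
The test changes mostly annotate the argument that @patch injects and give every test method an explicit -> None return; annotating the injected argument as MagicMock matches what this PR does for its autospec'd patches. A small standalone sketch of the same shape (test class and patched target are illustrative):

import socket
from unittest import TestCase
from unittest.mock import MagicMock, patch


class ExampleTest(TestCase):
    @patch("socket.gethostname", autospec=True)
    def test_hostname(self, gethostname_mock: MagicMock) -> None:
        # @patch injects the replacement as an extra positional argument,
        # which needs an annotation under strict mode
        gethostname_mock.return_value = "example-host"
        self.assertEqual(socket.gethostname(), "example-host")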