
Commit e4c8e5a

pyre strict (#29)
1 parent 9878980 commit e4c8e5a

10 files changed: +136 -87 lines

.pyre_configuration

+1 -0

@@ -1,4 +1,5 @@
 {
+  "strict": true,
   "site_package_search_strategy": "pep561",
   "source_directories": [
     {
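Setting "strict": true at the project level makes every module strict by default, so unannotated functions and attributes become errors instead of being silently typed as Any. A minimal sketch (not from this repo, illustrative names only) of what that changes:

# Hypothetical module run under pyre strict.
# Strict mode reports errors such as "Missing parameter annotation [2]" and
# "Missing return annotation [3]" on the first function; the annotated
# version passes.
def scale(x, factor=2):
    return x * factor

def scale_typed(x: int, factor: int = 2) -> int:
    return x * factor

The remaining friction, mostly from loosely typed third-party signatures, is then handled case by case with # pyre-fixme[...] suppressions, as in the files below.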

torchft/data.py

+5 -3

@@ -20,6 +20,7 @@
 from torch.utils import data
 
 
+# pyre-fixme[24]: expected generic parameter
 class DistributedSampler(data.distributed.DistributedSampler):
     """
     DistributedSampler extends the standard PyTorch DistributedSampler with a
@@ -49,7 +50,7 @@ def __init__(
         num_replica_groups: int,
         rank: Optional[int] = None,
         num_replicas: Optional[int] = None,
-        **kwargs,
+        **kwargs: object,
     ) -> None:
         """
         Args:
@@ -64,12 +65,13 @@ def __init__(
         if num_replicas is None:
             num_replicas = dist.get_world_size()
 
-        self.global_rank = rank + num_replicas * replica_group
-        self.global_world_size = num_replicas * num_replica_groups
+        self.global_rank: int = rank + num_replicas * replica_group
+        self.global_world_size: int = num_replicas * num_replica_groups
 
         super().__init__(
             dataset,
             rank=self.global_rank,
             num_replicas=self.global_world_size,
+            # pyre-fixme[6]: got object
             **kwargs,
         )
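Annotating **kwargs as object is the usual way to satisfy strict mode on a pass-through signature; the cost is that forwarding those values to a precisely typed callee triggers "Incompatible parameter type [6] ... got object", which is why the call site carries a # pyre-fixme[6]. A standalone sketch of the pattern (hypothetical names, not the sampler API):

# Illustrative only: a wrapper that forwards **kwargs to a typed callee.
def configure(shuffle: bool = True, seed: int = 0) -> None:
    print(shuffle, seed)

def configure_wrapper(**kwargs: object) -> None:
    # pyre-fixme[6]: values are typed `object`, callee expects bool/int
    configure(**kwargs)

configure_wrapper(shuffle=False, seed=42)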

torchft/ddp.py

+9 -4

@@ -44,12 +44,17 @@ class DistributedDataParallel(parallel.DistributedDataParallel):
     same across workers.
     """
 
-    def __init__(self, manager: "Manager", module: nn.Module, **args) -> None:
+    def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
         # use a dummy PG to soak up the init all reduce, actual comms will go
         # through the comm_hook.
         pg = ProcessGroupDummy(0, 1)
 
-        super().__init__(module, process_group=pg, **args)
+        super().__init__(
+            module,
+            process_group=pg,
+            # pyre-fixme[6]: got object
+            **kwargs,
+        )
 
         self.register_comm_hook(manager, self._comm_hook)
 
@@ -70,12 +75,12 @@ class PureDistributedDataParallel(nn.Module):
     may be very slow for real models.
     """
 
-    def __init__(self, manager: "Manager", module: nn.Module):
+    def __init__(self, manager: "Manager", module: nn.Module) -> None:
        super().__init__()
 
        self.module = module
 
-        def post_grad_hook(p):
+        def post_grad_hook(p: torch.Tensor) -> None:
            if p.grad is not None:
                manager.allreduce_grad(p.grad)
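The post_grad_hook edit shows the same rule applying to nested functions: under strict mode, inner callbacks need parameter and return annotations too. A self-contained sketch of a typed per-parameter gradient hook (illustrative wiring, not torchft's comm-hook path; register_post_accumulate_grad_hook needs a recent PyTorch, roughly 2.1+):

import torch
from torch import nn

def install_grad_norm_hooks(module: nn.Module) -> None:
    # Strict mode requires the nested callback to be annotated as well.
    def post_grad_hook(p: torch.Tensor) -> None:
        if p.grad is not None:
            print(p.grad.norm().item())

    for p in module.parameters():
        p.register_post_accumulate_grad_hook(post_grad_hook)

m = nn.Linear(3, 4)
install_grad_norm_hooks(m)
m(torch.randn(2, 3)).sum().backward()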

torchft/ddp_test.py

+2 -2

@@ -18,7 +18,7 @@
 
 
 class TestDDP(TestCase):
-    def test_pure_ddp(self):
+    def test_pure_ddp(self) -> None:
         manager = create_autospec(Manager)
 
         m = nn.Linear(3, 4)
@@ -34,7 +34,7 @@ def test_pure_ddp(self):
 
         self.assertEqual(manager.allreduce_grad.call_count, len(list(m.parameters())))
 
-    def test_ddp(self):
+    def test_ddp(self) -> None:
         manager = create_autospec(Manager)
 
         call_count = 0

torchft/http.py

+2 -2

@@ -3,5 +3,5 @@
 
 
 class _IPv6HTTPServer(ThreadingHTTPServer):
-    address_family = socket.AF_INET6
-    request_queue_size = 1024
+    address_family: socket.AddressFamily = socket.AF_INET6
+    request_queue_size: int = 1024
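For anyone wondering why the override is annotated socket.AddressFamily rather than plain int: AF_INET6 is a member of the socket.AddressFamily IntEnum, so the explicit annotation just names the value's actual type. A quick check (illustrative, not part of the commit):

import socket

# AF_INET6 is an IntEnum member, so the attribute annotation mirrors that type.
assert isinstance(socket.AF_INET6, socket.AddressFamily)
print(type(socket.AF_INET6))  # <enum 'AddressFamily'>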

torchft/manager.py

+13 -10

@@ -25,6 +25,7 @@
 
 """
 
+import concurrent.futures
 import logging
 import os
 import socket
@@ -35,8 +36,7 @@
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
-from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work
-from torch.optim import Optimizer
+from torch.distributed import ReduceOp, TCPStore
 
 from torchft.checkpointing import CheckpointServer
 from torchft.torchft import Manager as _Manager, ManagerClient
@@ -81,8 +81,8 @@ class Manager:
     def __init__(
         self,
         pg: "ProcessGroup",
-        load_state_dict: Callable[[object], None],
-        state_dict: Callable[[], object],
+        load_state_dict: Callable[[T], None],
+        state_dict: Callable[[], T],
         min_replica_size: int,
         port: int = MANAGER_DEFAULT_PORT,
         use_async_quorum: bool = True,
@@ -124,14 +124,15 @@ def __init__(
         world_size = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
 
-        self._ckpt_server = CheckpointServer(
-            lambda: {
+        def _manager_state_dict() -> Dict[str, T]:
+            return {
                 "user": state_dict(),
-                "torchft": self.state_dict(),
+                "torchft": cast(T, self.state_dict()),
             }
-        )
+
+        self._ckpt_server = CheckpointServer[Dict[str, T]](_manager_state_dict)
         self._executor = ThreadPoolExecutor(max_workers=1)
-        self._quorum_future = None
+        self._quorum_future: Optional[concurrent.futures.Future] = None
 
         self._store = TCPStore(
             host_name=store_addr,
@@ -140,7 +141,7 @@ def __init__(
             wait_for_workers=False,
         )
         self._pg = pg
-        self._manager = None
+        self._manager: Optional[_Manager] = None
 
         if rank == 0:
             hostname = socket.gethostname()
@@ -208,6 +209,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             fut.set_result(grad)
             return fut
 
+        assert self._quorum_future is not None, "must call step before allreduce_grad"
         self._quorum_future.result()
 
         if not self.is_participating():
@@ -397,6 +399,7 @@ def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
 
         # synchronize on future
+        assert self._quorum_future is not None, "must call step before should_commit"
         self._quorum_future.result()
 
         assert self._pending_state_dict is not None, "checkpoint was not staged"
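Two typing patterns show up here. The asserts before self._quorum_future.result() narrow the new Optional[Future] annotation so strict mode accepts the call. Separately, switching load_state_dict / state_dict from object to the TypeVar T ties the loader and the snapshot function to the same user state type, which is what lets _manager_state_dict be typed as Dict[str, T]. A reduced sketch of that pairing (illustrative only; the real Manager takes many more arguments):

from typing import Callable, Dict, Generic, TypeVar

T = TypeVar("T")

class StateHolder(Generic[T]):
    """Holds matched load/save callbacks for a user-defined state type."""

    def __init__(
        self,
        load_state_dict: Callable[[T], None],
        state_dict: Callable[[], T],
    ) -> None:
        self._load_state_dict = load_state_dict
        self._state_dict = state_dict

    def snapshot(self) -> Dict[str, T]:
        # mirrors _manager_state_dict() above: every value shares the type T
        return {"user": self._state_dict()}

    def restore(self, state: T) -> None:
        self._load_state_dict(state)

# usage with a plain dict-valued state
holder = StateHolder(
    load_state_dict=lambda s: print("restoring", s),
    state_dict=lambda: {"step": 0},
)
print(holder.snapshot())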

torchft/manager_integ_test.py

+19 -13

@@ -1,6 +1,6 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
-from typing import Set, Tuple
+from typing import Dict, Set, Tuple
 from unittest import TestCase
 
 import torch
@@ -15,14 +15,14 @@
 
 
 class MyModel(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.model = nn.Sequential(
             nn.Linear(3, 4),
             nn.Sigmoid(),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)
 
 
@@ -52,7 +52,7 @@ def worker_manager(
     lighthouse_address: str,
     failure_injector: FailureInjector,
     attempts: int = 3,
-) -> None:
+) -> Dict[str, Dict[str, object]]:
     for i in range(attempts):
         try:
             print(f"starting worker {replica_id} attempt {i}")
@@ -65,10 +65,12 @@ def worker_manager(
                 raise
             continue
 
+    raise RuntimeError("ran out of attempts")
+
 
 def train_loop(
     replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
-) -> None:
+) -> Dict[str, Dict[str, object]]:
     with ExitStack() as stack:
         store = dist.TCPStore(
             host_name="localhost",
@@ -77,11 +79,11 @@ def train_loop(
             wait_for_workers=False,
         )
 
-        def load_state_dict(state_dict):
+        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
             m.load_state_dict(state_dict["model"])
             optimizer.load_state_dict(state_dict["optim"])
 
-        def state_dict():
+        def state_dict() -> Dict[str, Dict[str, object]]:
             return {
                 "model": m.state_dict(),
                 "optim": optimizer.state_dict(),
@@ -103,8 +105,10 @@ def state_dict():
         )
         stack.callback(manager.shutdown)
 
-        m = DistributedDataParallel(manager, MyModel())
-        optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
+        m: nn.Module = DistributedDataParallel(manager, MyModel())
+        optimizer: optim.Optimizer = OptimizerWrapper(
+            manager, optim.Adam(m.parameters())
+        )
         criterion = nn.CrossEntropyLoss()
 
         while True:
@@ -120,14 +124,16 @@ def state_dict():
             optimizer.step()
 
             if manager.current_step() >= 5:
-                # return state_dict so we can check consistency
-                return state_dict()
+                break
 
             failure_injector.check(manager.current_step())
 
+        # return state_dict so we can check consistency
+        return state_dict()
+
 
 class ManagerIntegTest(TestCase):
-    def test_ddp_healthy(self):
+    def test_ddp_healthy(self) -> None:
         lighthouse = Lighthouse(
             bind="[::]:0",
             min_replicas=2,
@@ -157,7 +163,7 @@ def test_ddp_healthy(self):
         for state_dict in state_dicts:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
-    def test_ddp_recovery(self):
+    def test_ddp_recovery(self) -> None:
         lighthouse = Lighthouse(
             bind="[::]:0",
             min_replicas=2,
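The added raise RuntimeError("ran out of attempts") is a typing fix as much as a behavioral one: once worker_manager declares a non-Optional return type, every path has to return a value or raise, and a retry loop that exhausts its attempts would otherwise fall through and implicitly return None. A standalone sketch of the same pattern (hypothetical helper, not the test code):

from typing import Callable, TypeVar

R = TypeVar("R")

def retry(fn: Callable[[], R], attempts: int = 3) -> R:
    for i in range(attempts):
        try:
            return fn()
        except RuntimeError:
            if i == attempts - 1:
                raise
            continue
    # Without this, the function could fall off the end and return None,
    # which the declared return type R does not allow.
    raise RuntimeError("ran out of attempts")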

torchft/manager_test.py

+15 -11

@@ -51,12 +51,12 @@ def _create_manager(
         return manager
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_manager(self, client_mock) -> None:
+    def test_manager(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
         self.assertEqual(client_mock.call_count, 1)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_state_dict(self, client_mock) -> None:
+    def test_state_dict(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
         state_dict = manager.state_dict()
@@ -78,7 +78,7 @@ def test_state_dict(self, client_mock) -> None:
         self.assertEqual(manager.batches_committed(), 2345)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_quorum_happy(self, client_mock) -> None:
+    def test_quorum_happy(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
         client_mock().should_commit = lambda rank, step, should_commit: should_commit
 
@@ -113,7 +113,7 @@ def test_quorum_happy(self, client_mock) -> None:
         self.assertEqual(manager.batches_committed(), 2)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_quorum_heal_sync(self, client_mock) -> None:
+    def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(use_async_quorum=False)
         client_mock().should_commit = lambda rank, step, should_commit: should_commit
 
@@ -153,7 +153,9 @@ def test_quorum_heal_sync(self, client_mock) -> None:
         self.assertEqual(self.load_state_dict.call_count, 1)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
+    def test_quorum_heal_async_not_enough_participants(
+        self, client_mock: MagicMock
+    ) -> None:
         manager = self._create_manager(use_async_quorum=True, min_replica_size=2)
         client_mock().should_commit = lambda rank, step, should_commit: should_commit
 
@@ -177,6 +179,7 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
         self.assertEqual(manager._step, 0)
 
         manager.step()
+        assert manager._quorum_future is not None
         manager._quorum_future.result()
         self.assertTrue(manager._healing)
         self.assertFalse(manager.is_participating())
@@ -204,7 +207,7 @@ def test_quorum_heal_async_not_enough_participants(self, client_mock) -> None:
         self.assertEqual(manager.batches_committed(), 0)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_quorum_heal_async_zero_grad(self, client_mock) -> None:
+    def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(use_async_quorum=True, min_replica_size=1)
         client_mock().should_commit = lambda rank, step, should_commit: should_commit
 
@@ -228,6 +231,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock) -> None:
         self.assertEqual(manager._step, 0)
 
         manager.step()
+        assert manager._quorum_future is not None
         manager._quorum_future.result()
         self.assertTrue(manager._healing)
 
@@ -253,7 +257,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock) -> None:
         self.assertEqual(manager.batches_committed(), 1)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_allreduce_error(self, client_mock) -> None:
+    def test_allreduce_error(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
         client_mock().should_commit = lambda rank, step, should_commit: should_commit
 
@@ -338,7 +342,7 @@ def test_allreduce_error(self, client_mock) -> None:
         self.assertTrue(manager.should_commit())
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_quorum_fixed_world_size(self, client_mock) -> None:
+    def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
         # test active and spares
         for rank in [1, 2]:
             manager = self._create_manager(
@@ -375,15 +379,15 @@ def test_quorum_fixed_world_size(self, client_mock) -> None:
         self.assertEqual(manager.batches_committed(), 2)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_manager_report_error(self, client_mock) -> None:
+    def test_manager_report_error(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
         self.assertFalse(manager.errored())
         manager.report_error()
         self.assertTrue(manager.errored())
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_manager_wrap_future(self, client_mock) -> None:
+    def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
         self.assertFalse(manager.errored())
@@ -398,7 +402,7 @@ def test_manager_wrap_future(self, client_mock) -> None:
         self.assertEqual(manager._pending_work, [wrapped_fut])
 
     @patch("torchft.manager.ManagerClient", autospec=True)
-    def test_manager_numerics(self, client_mock) -> None:
+    def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
         manager._quorum_future = MagicMock()
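Most of the test changes simply annotate the @patch-injected argument as MagicMock and declare -> None returns, which is what strict mode asks of test methods too. A minimal sketch of the pattern outside this repo:

import socket
from unittest import TestCase
from unittest.mock import MagicMock, patch

class HostnameTest(TestCase):
    @patch("socket.gethostname", autospec=True)
    def test_hostname(self, gethostname_mock: MagicMock) -> None:
        # the patch decorator injects the mock; annotating it keeps pyre strict happy
        gethostname_mock.return_value = "testhost"
        self.assertEqual(socket.gethostname(), "testhost")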
