1
1
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from typing import Dict, Set, Tuple
from unittest import TestCase

import torch
@@ -24,63 +26,108 @@ def forward(self, x):
24
26
return self .model (x )
25
27
26
28
27
def train_loop(replica_id: int, lighthouse_address: str) -> None:
    """Train MyModel under a fault-tolerant Manager until step 5, then shut down.

    Args:
        replica_id: index of this replica; also gives it a unique manager port.
        lighthouse_address: address of the Lighthouse coordination service.
    """
    # Fresh per-replica store; port=0 lets the OS pick a free port, which is
    # then handed to the Manager via store.port.
    store = dist.TCPStore(
        host_name="localhost",
        port=0,
        is_master=True,
        wait_for_workers=False,
    )

    # NOTE(review): these callbacks close over m / optimizer, which are only
    # assigned after Manager construction below — presumably the Manager
    # invokes them lazily once training starts; confirm before reordering.
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "optim": optimizer.state_dict(),
        }

    pg = ProcessGroupGloo()
    manager = Manager(
        pg=pg,
        min_replica_size=2,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=str(replica_id),
        store_addr="localhost",
        store_port=store.port,
        rank=0,
        world_size=1,
        lighthouse_addr=lighthouse_address,
        port=19530 + replica_id,  # unique port per replica
    )
    m = DistributedDataParallel(manager, MyModel())
    optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
    criterion = nn.CrossEntropyLoss()

    while True:
        # Synthetic batch: 2 samples, 3 features; labels drawn from 4 classes.
        inputs = torch.rand(2, 3)
        labels = torch.randint(4, (2,))

        optimizer.zero_grad()
        out = m(inputs)
        loss = criterion(out, labels)

        loss.backward()
        optimizer.step()

        # TODO: assert weights are equal across replicas

        # Stop after the manager has observed 5 completed steps.
        if manager.current_step() >= 5:
            break

    manager.shutdown()
29
class InjectedFailure(Exception):
    """Raised by FailureInjector.check() to simulate a worker crash mid-training."""

    pass
31
+
32
+
33
class FailureInjector:
    """Deterministically injects failures at pre-registered training steps.

    Steps are armed via fail_at(); each armed step fires exactly once by
    raising InjectedFailure from check(), and every firing is tallied in
    ``count`` so tests can assert how many failures actually occurred.
    """

    def __init__(self) -> None:
        self._failures: Set[int] = set()
        self.count = 0

    def fail_at(self, step: int) -> "FailureInjector":
        # Arm a one-shot failure for ``step``; fluent so calls can chain.
        self._failures.add(step)
        return self

    def check(self, step: int) -> None:
        # Guard clause: nothing armed for this step.
        if step not in self._failures:
            return
        self._failures.remove(step)  # one-shot: disarm before raising
        self.count += 1
        print(f"injecting failure {step=}")
        raise InjectedFailure(f"injected failure {step=}")
48
+
49
+
50
def worker_manager(
    replica_id: int,
    lighthouse_address: str,
    failure_injector: FailureInjector,
    attempts: int = 3,
) -> Dict[str, object]:
    """Supervise one replica's train_loop, restarting it after injected failures.

    Args:
        replica_id: index of this replica, forwarded to train_loop.
        lighthouse_address: address of the Lighthouse coordination service.
        failure_injector: injector whose check() may raise InjectedFailure to
            simulate a crash at a given step.
        attempts: maximum number of times to (re)start train_loop.

    Returns:
        The final state dict produced by train_loop
        ({"model": ..., "optim": ...}).

    Raises:
        ValueError: if attempts < 1.
        InjectedFailure: if the worker still fails on its final attempt.
    """
    if attempts < 1:
        # Previously this silently returned None; make the misuse explicit.
        raise ValueError(f"attempts must be >= 1, got {attempts}")
    for attempt in range(attempts):
        try:
            print(f"starting worker {replica_id} attempt {attempt}")
            return train_loop(
                replica_id, lighthouse_address, failure_injector=failure_injector
            )
        except InjectedFailure as e:
            print("got injected failure", attempt, e)
            if attempt == attempts - 1:
                raise
    raise AssertionError("unreachable: loop always returns or re-raises")
67
+
68
+
69
def train_loop(
    replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
) -> Dict[str, object]:
    """Train MyModel under a fault-tolerant Manager until step 5.

    Args:
        replica_id: index of this replica; also gives it a unique manager port.
        lighthouse_address: address of the Lighthouse coordination service.
        failure_injector: consulted after each step; may raise InjectedFailure
            to simulate a crash (worker_manager restarts the loop).

    Returns:
        A {"model": ..., "optim": ...} snapshot so the tests can verify all
        replicas converged to identical state.
    """
    with ExitStack() as stack:
        # Fresh per-replica store; port=0 lets the OS pick a free port, which
        # is then handed to the Manager via store.port.
        store = dist.TCPStore(
            host_name="localhost",
            port=0,
            is_master=True,
            wait_for_workers=False,
        )

        # NOTE(review): these callbacks close over m / optimizer, which are
        # only assigned after Manager construction below — presumably the
        # Manager invokes them lazily once training starts; confirm before
        # reordering.
        def load_state_dict(state_dict):
            m.load_state_dict(state_dict["model"])
            optimizer.load_state_dict(state_dict["optim"])

        def state_dict():
            return {
                "model": m.state_dict(),
                "optim": optimizer.state_dict(),
            }

        pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(replica_id),
            store_addr="localhost",
            store_port=store.port,
            rank=0,
            world_size=1,
            lighthouse_addr=lighthouse_address,
            port=19530 + replica_id,  # unique port per replica
        )
        # Guarantee shutdown even when an InjectedFailure unwinds the stack.
        stack.callback(manager.shutdown)

        m = DistributedDataParallel(manager, MyModel())
        optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
        criterion = nn.CrossEntropyLoss()

        while True:
            print(f"worker {replica_id} starting step {manager.current_step()}")
            # Synthetic batch: 2 samples, 3 features; labels from 4 classes.
            inputs = torch.rand(2, 3)
            labels = torch.randint(4, (2,))

            optimizer.zero_grad()
            out = m(inputs)
            loss = criterion(out, labels)

            loss.backward()
            optimizer.step()

            if manager.current_step() >= 5:
                # return state_dict so we can check consistency
                return state_dict()

            # NOTE(review): checked after the step-5 return above, so a
            # failure armed at or beyond the exit step can never fire.
            failure_injector.check(manager.current_step())
80
127
81
128
82
129
class ManagerIntegTest (TestCase ):
83
- def test_ddp (self ):
130
+ def test_ddp_healthy (self ):
84
131
lighthouse = Lighthouse (
85
132
bind = "[::]:0" ,
86
133
min_replicas = 2 ,
@@ -90,11 +137,60 @@ def test_ddp(self):
90
137
91
138
with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
92
139
for replica_id in range (num_replicas ):
140
+ failure_injector = FailureInjector ()
141
+ futures .append (
142
+ executor .submit (
143
+ worker_manager ,
144
+ replica_id ,
145
+ lighthouse .address (),
146
+ failure_injector = failure_injector ,
147
+ )
148
+ )
149
+
150
+ state_dicts = []
151
+
152
+ for fut in as_completed (futures ):
153
+ state_dicts .append (fut .result ())
154
+
155
+ lighthouse .shutdown ()
156
+
157
+ for state_dict in state_dicts :
158
+ torch .testing .assert_close (state_dict , state_dicts [0 ])
159
+
160
+ def test_ddp_recovery (self ):
161
+ lighthouse = Lighthouse (
162
+ bind = "[::]:0" ,
163
+ min_replicas = 2 ,
164
+ )
165
+ num_replicas = 2
166
+ futures = []
167
+
168
+ failure_injectors = [
169
+ FailureInjector (),
170
+ FailureInjector ().fail_at (2 ),
171
+ ]
172
+
173
+ with ThreadPoolExecutor (max_workers = num_replicas ) as executor :
174
+ for replica_id , failure_injector in zip (
175
+ range (num_replicas ), failure_injectors
176
+ ):
93
177
futures .append (
94
- executor .submit (train_loop , replica_id , lighthouse .address ())
178
+ executor .submit (
179
+ worker_manager ,
180
+ replica_id ,
181
+ lighthouse .address (),
182
+ failure_injector = failure_injector ,
183
+ )
95
184
)
96
185
186
+ state_dicts = []
187
+
97
188
for fut in as_completed (futures ):
98
- fut .result ()
189
+ state_dicts . append ( fut .result () )
99
190
100
191
lighthouse .shutdown ()
192
+
193
+ for state_dict in state_dicts :
194
+ torch .testing .assert_close (state_dict , state_dicts [0 ])
195
+
196
+ self .assertEqual (failure_injectors [1 ].count , 1 )
0 commit comments