45
45
if TYPE_CHECKING :
46
46
from torchft .process_group import ProcessGroup
47
47
48
- logger : logging .Logger = logging .getLogger (__name__ )
49
-
50
48
MANAGER_ADDR_KEY : str = "manager_addr"
51
49
MANAGER_DEFAULT_PORT : int = int (os .environ .get ("TORCHFT_MANAGER_PORT" , 29511 ))
50
+ REPLICA_ID_KEY : str = "replica_id"
52
51
53
52
T = TypeVar ("T" )
54
53
@@ -132,7 +131,9 @@ def _manager_state_dict() -> Dict[str, T]:
132
131
}
133
132
134
133
self ._ckpt_server = CheckpointServer [Dict [str , T ]](_manager_state_dict )
135
- self ._executor = ThreadPoolExecutor (max_workers = 1 )
134
+ self ._executor = ThreadPoolExecutor (
135
+ max_workers = 1 , thread_name_prefix = "async_quorum"
136
+ )
136
137
self ._quorum_future : Optional [concurrent .futures .Future ] = None
137
138
138
139
self ._store = TCPStore (
@@ -163,10 +164,16 @@ def _manager_state_dict() -> Dict[str, T]:
163
164
)
164
165
165
166
self ._store .set (MANAGER_ADDR_KEY , addr )
167
+ self ._store .set (REPLICA_ID_KEY , replica_id )
166
168
167
169
addr = self ._store .get (MANAGER_ADDR_KEY ).decode ("utf-8" )
168
170
self ._client = ManagerClient (addr , timeout = timeout )
169
171
172
+ replica_id = self ._store .get (REPLICA_ID_KEY ).decode ("utf-8" )
173
+ self ._logger = _ManagerLogger (
174
+ manager = self , replica_id = replica_id or "" , rank = rank
175
+ )
176
+
170
177
self ._step = 0
171
178
self ._quorum_id = - 1
172
179
self ._errored : Optional [Exception ] = None
@@ -230,6 +237,7 @@ def callback(
230
237
) -> torch .Tensor :
231
238
nonlocal grad
232
239
240
+ # check for exceptions
233
241
fut .value ()
234
242
235
243
grad /= self .num_participants ()
@@ -241,7 +249,9 @@ def callback(
241
249
return fut
242
250
243
251
except Exception as e :
244
- logger .exception (f"got exception in all reduce -- skipping remaining: { e } " )
252
+ self ._logger .exception (
253
+ f"got exception in all reduce -- skipping remaining: { e } "
254
+ )
245
255
self .report_error (e )
246
256
247
257
fut = torch .futures .Future () # pyre-fixme[29]: not a function
@@ -294,7 +304,9 @@ def callback(
294
304
try :
295
305
return fut .value ()
296
306
except Exception as e :
297
- logger .exception (f"got exception in future -- skipping remaining: { e } " )
307
+ self ._logger .exception (
308
+ f"got exception in future -- skipping remaining: { e } "
309
+ )
298
310
self .report_error (e )
299
311
return default
300
312
@@ -328,12 +340,13 @@ def step(self) -> None:
328
340
if not self ._use_async_quorum :
329
341
self ._quorum_future .result ()
330
342
331
- # eagerly apply pending state_dict so we can run the forwards pass
332
- self ._apply_pending_state_dict ()
343
+ if self ._healing :
344
+ # eagerly apply pending state_dict so we can run the forwards pass
345
+ self ._apply_pending_state_dict ()
333
346
334
- # we are forcing healing at the beginning so we're in a good state
335
- # and don't need to zero_grad
336
- self ._healing = False
347
+ # we are forcing healing at the beginning so we're in a good state
348
+ # and don't need to zero_grad
349
+ self ._healing = False
337
350
338
351
def _async_quorum (self ) -> None :
339
352
(
@@ -374,21 +387,23 @@ def _async_quorum(self) -> None:
374
387
self ._participating_rank = None
375
388
376
389
if quorum_id != self ._quorum_id :
377
- logger .info (f"{ replica_rank = } reconfiguring for quorum_id { quorum_id } " )
378
390
store_prefixed_addr = f"{ store_address } /torchft/{ quorum_id } /{ self ._rank } "
391
+
392
+ self ._logger .info (f"reconfiguring for { quorum_id = } { store_prefixed_addr = } " )
379
393
# We use the replica rank and world as we want all replicas in the PG.
380
394
self ._pg .configure (store_prefixed_addr , replica_rank , replica_world_size )
381
395
self ._quorum_id = quorum_id
382
396
383
397
# See manager.rs for healing conditions
384
398
if heal :
385
399
self ._healing = True
386
- logger . info (f" { replica_rank } = healing required" )
387
-
388
- logger . info ( f"fetching checkpoint server address from { address } " )
400
+ self . _logger . info (
401
+ f"healing required, fetching checkpoint server address from { address = } { max_step = } "
402
+ )
389
403
primary_client = ManagerClient (address , timeout = self ._timeout )
390
404
checkpoint_server_address = primary_client .checkpoint_address (self ._rank )
391
405
406
+ self ._logger .info (f"fetching checkpoint from { checkpoint_server_address = } " )
392
407
self ._pending_state_dict = CheckpointServer .load_from_address (
393
408
checkpoint_server_address
394
409
)
@@ -406,8 +421,9 @@ def _apply_pending_state_dict(self) -> None:
406
421
assert self ._quorum_future is not None , "must call step before should_commit"
407
422
self ._quorum_future .result ()
408
423
409
- assert self ._pending_state_dict is not None , "checkpoint was not staged"
424
+ self ._logger . info ( "applying pending state dict" )
410
425
426
+ assert self ._pending_state_dict is not None , "checkpoint was not staged"
411
427
self ._load_state_dict (self ._pending_state_dict ["user" ])
412
428
self ._pending_state_dict = None
413
429
@@ -450,7 +466,7 @@ def should_commit(self) -> bool:
450
466
should_commit = self ._client .should_commit (
451
467
self ._rank , self ._step , local_should_commit
452
468
)
453
- logger .info (
469
+ self . _logger .info (
454
470
f"should_commit={ should_commit } enough_replicas={ enough_replicas } , errored={ self ._errored } "
455
471
)
456
472
@@ -534,3 +550,25 @@ def is_participating(self) -> bool:
534
550
assert self ._use_async_quorum
535
551
return False
536
552
return True
553
+
554
+
555
+ class _ManagerLogger :
556
+ def __init__ (self , manager : Manager , replica_id : str , rank : int ) -> None :
557
+ self ._logger : logging .Logger = logging .getLogger (__name__ )
558
+ self ._replica_id = replica_id
559
+ self ._rank = rank
560
+ self ._manager = manager
561
+
562
+ def prefix (self ) -> str :
563
+ return (
564
+ f"[{ self ._replica_id } /{ self ._rank } - step { self ._manager .current_step ()} ]"
565
+ )
566
+
567
+ def info (self , msg : str ) -> None :
568
+ self ._logger .info (f"{ self .prefix ()} { msg } " )
569
+
570
+ def warn (self , msg : str ) -> None :
571
+ self ._logger .warn (f"{ self .prefix ()} { msg } " )
572
+
573
+ def exception (self , msg : str ) -> None :
574
+ self ._logger .exception (f"{ self .prefix ()} { msg } " )
0 commit comments