@@ -164,8 +164,7 @@ def __init__(self, callbacks=None, queue_length=10, master_rank=None):
        callbacks = callbacks or []
        self.callbacks = [c for c in callbacks]
        self.queue_length = queue_length
-        if master_rank:
-            self.master_rank = master_rank
+        self.master_rank = master_rank

    def append(self, callback):
        self.callbacks.append(callback)
@@ -178,15 +177,10 @@ def set_model(self, model):
        for callback in self.callbacks:
            callback.set_model(model)

-    def return_distributed(self):
-        '''In distributed training, ensure that only master_rank executes the callbacks
-        '''
+    def on_epoch_begin(self, global_step, epoch, logs=None):
        # In distributed DDP training, only master_rank runs the callbacks
-        if hasattr(self, 'master_rank') and self.master_rank != torch.distributed.get_rank():
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
            return
-
-    def on_epoch_begin(self, global_step, epoch, logs=None):
-        self.return_distributed()
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(global_step, epoch, logs)
@@ -195,13 +189,15 @@ def on_epoch_begin(self, global_step, epoch, logs=None):
        self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

    def on_epoch_end(self, global_step, epoch, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
+            return
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(global_step, epoch, logs)

    def on_batch_begin(self, global_step, local_step, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
+            return
        logs = logs or {}
        t_before_callbacks = time.time()
        for callback in self.callbacks:
@@ -213,7 +209,8 @@ def on_batch_begin(self, global_step, local_step, logs=None):
        self._t_enter_batch = time.time()

    def on_batch_end(self, global_step, local_step, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
+            return
        logs = logs or {}
        if not hasattr(self, '_t_enter_batch'):
            self._t_enter_batch = time.time()
@@ -227,19 +224,22 @@ def on_batch_end(self, global_step, local_step, logs=None):
            warnings.warn(f'Method on_batch_end() is slow compared to the batch update {delta_t_median}. Check your callbacks.')

    def on_train_begin(self, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
+            return
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
+            return
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_end(logs)

    def on_dataloader_end(self, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
+            return
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_dataloader_end(logs)
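
The change matters because rank 0 is falsy in Python: with the old `if master_rank:` guard, passing master_rank=0 never set the attribute, so every DDP process ended up running the callbacks. Below is a minimal sketch, not the library's actual class (the name CallbacksSketch is hypothetical), of how the new `is not None` guard gates callback execution; it assumes torch.distributed has already been initialized (e.g. via torchrun) before the hooks fire.

import torch

class CallbacksSketch:
    def __init__(self, master_rank=None):
        # Old behaviour: `if master_rank:` skipped the assignment when master_rank == 0,
        # so the gate silently disappeared for the usual master rank.
        # New behaviour: always store it; None means "run callbacks on every process".
        self.master_rank = master_rank

    def on_epoch_begin(self, global_step, epoch, logs=None):
        # Skip callbacks on non-master ranks only when a master_rank was given.
        if (self.master_rank is not None) and (self.master_rank != torch.distributed.get_rank()):
            return
        print(f'rank {torch.distributed.get_rank()}: running epoch {epoch} callbacks')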