
Commit 79c92e3 (v0.0.3.post2)
1 parent: d19e71a

2 files changed: +18 -17 lines

torch4keras/model.py

+3 -2
@@ -140,10 +140,11 @@ def fit(self, train_dataloader, steps_per_epoch=None, epochs=1, callbacks=None,
         if not isinstance(callbacks, (list, tuple)):
             callbacks = [callbacks]
         for callback in callbacks:
-            assert isinstance(callback, Callback), "Args 'callbacks' only support Callback() inputs"
+            assert isinstance(callback, Callback), "Args `callbacks` only support Callback() inputs"
         progbarlogger = ProgbarLogger(stateful_metrics=self.stateful_metrics)  # progress bar
         history = History()
-        self.callbacks = CallbackList([BaseLogger(self.stateful_metrics), progbarlogger] + callbacks + [history])
+        master_rank = self.master_rank if hasattr(self, 'master_rank') else None
+        self.callbacks = CallbackList([BaseLogger(self.stateful_metrics), progbarlogger] + callbacks + [history], master_rank=master_rank)
         callback_model = self.module if hasattr(self, 'module') else self
         self.callbacks.set_model(callback_model)
         self.callbacks.set_params({
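
The change above lets fit() pick up an optional master_rank attribute from the model and forward it into CallbackList, so that under DDP only one process drives the progress bar and the other callbacks. A minimal sketch of how a caller might wire this up; the DDP bootstrap and the build_model()/train_dataloader/my_callback names are illustrative placeholders, not part of this commit:

    import torch.distributed as dist

    # Illustrative DDP bootstrap (e.g. launched via torchrun).
    dist.init_process_group(backend='nccl')

    model = build_model()          # hypothetical helper returning a torch4keras model
    model.master_rank = 0          # fit() detects this via hasattr(self, 'master_rank')
    model.fit(train_dataloader, epochs=3, callbacks=[my_callback])
    # Only the process with dist.get_rank() == 0 now runs the callbacks.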

torch4keras/snippets.py

+15 -15
@@ -164,8 +164,7 @@ def __init__(self, callbacks=None, queue_length=10, master_rank=None):
         callbacks = callbacks or []
         self.callbacks = [c for c in callbacks]
         self.queue_length = queue_length
-        if master_rank:
-            self.master_rank = master_rank
+        self.master_rank = master_rank

     def append(self, callback):
         self.callbacks.append(callback)
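
Storing master_rank unconditionally matters because the previous "if master_rank:" guard treats rank 0, the usual choice of master rank, as falsy, so the attribute was never set in that case. A quick standalone illustration of the truthiness pitfall (not code from the commit):

    master_rank = 0
    if master_rank:                  # False for rank 0: the old guard skips the assignment
        print('old guard would store it')
    if master_rank is not None:      # True: an explicit None check keeps rank 0
        print('explicit check stores it')
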
@@ -178,15 +177,10 @@ def set_model(self, model):
         for callback in self.callbacks:
             callback.set_model(model)

-    def return_distributed(self):
-        '''In distributed training, only master_rank executes the callbacks
-        '''
+    def on_epoch_begin(self, global_step, epoch, logs=None):
         # Under distributed DDP training, only master_rank runs the callbacks
-        if hasattr(self, 'master_rank') and self.master_rank!=torch.distributed.get_rank():
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
             return
-
-    def on_epoch_begin(self, global_step, epoch, logs=None):
-        self.return_distributed()
         logs = logs or {}
         for callback in self.callbacks:
             callback.on_epoch_begin(global_step, epoch, logs)
@@ -195,13 +189,15 @@ def on_epoch_begin(self, global_step, epoch, logs=None):
         self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

     def on_epoch_end(self, global_step, epoch, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
+            return
         logs = logs or {}
         for callback in self.callbacks:
             callback.on_epoch_end(global_step, epoch, logs)

     def on_batch_begin(self, global_step, local_step, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
+            return
         logs = logs or {}
         t_before_callbacks = time.time()
         for callback in self.callbacks:
@@ -213,7 +209,8 @@ def on_batch_begin(self, global_step, local_step, logs=None):
         self._t_enter_batch = time.time()

     def on_batch_end(self, global_step, local_step, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
+            return
         logs = logs or {}
         if not hasattr(self, '_t_enter_batch'):
             self._t_enter_batch = time.time()
@@ -227,19 +224,22 @@ def on_batch_end(self, global_step, local_step, logs=None):
             warnings.warn(f'Method on_batch_end() is slow compared to the batch update {delta_t_median}. Check your callbacks.')

     def on_train_begin(self, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
+            return
         logs = logs or {}
         for callback in self.callbacks:
             callback.on_train_begin(logs)

     def on_train_end(self, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
+            return
         logs = logs or {}
         for callback in self.callbacks:
             callback.on_train_end(logs)

     def on_dataloader_end(self, logs=None):
-        self.return_distributed()
+        if (self.master_rank is not None) and (self.master_rank!=torch.distributed.get_rank()):
+            return
         logs = logs or {}
         for callback in self.callbacks:
             callback.on_dataloader_end(logs)
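
The same guard is now inlined at the top of every hook, replacing the old return_distributed() helper, whose early return only exited the helper itself rather than the calling hook, so callbacks still ran on every rank. A minimal standalone sketch of the gating pattern, assuming torch.distributed has already been initialized; MiniCallbackList is illustrative, not the torch4keras class:

    import torch.distributed as dist

    class MiniCallbackList:
        '''Toy reimplementation of the master_rank gate shown in the diff above.'''
        def __init__(self, callbacks=None, master_rank=None):
            self.callbacks = list(callbacks or [])
            self.master_rank = master_rank  # None means "run callbacks on every rank"

        def on_epoch_end(self, global_step, epoch, logs=None):
            # Skip the callbacks on every rank except master_rank when one is set.
            if (self.master_rank is not None) and (self.master_rank != dist.get_rank()):
                return
            logs = logs or {}
            for callback in self.callbacks:
                callback.on_epoch_end(global_step, epoch, logs)

With master_rank=0, ranks 1..N-1 return immediately from every hook, while rank 0 still drives logging and checkpointing.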
