Skip to content

Commit 51697a3

Browse files
committed
tf/estimator.py: only write checkpoint in rank0 (#447)
1 parent 0c9e540 commit 51697a3

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

python/raydp/tf/estimator.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -188,9 +188,14 @@ def train_func(config):
188 188
if config["evaluate"]:
189 189
test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
190 190
results.append(test_history)
191+
192+
# Only save checkpoint from the chief worker to avoid race conditions.
193+
# However, we need to call save on all workers to avoid deadlock.
191 194
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
192 195
multi_worker_model.save(temp_checkpoint_dir, save_format="tf")
193-
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
196+
checkpoint = None
197+
if session.get_world_rank() == 0:
198+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
194 199

195 200
session.report({}, checkpoint=checkpoint)
196 201

0 commit comments

Comments
 (0)