Skip to content

Commit afdf80e

Browse files
committed
[Embedding] Check the sharded property of tf.train.Saver.
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 93c69ad commit afdf80e

File tree

18 files changed: +33 additions, -32 deletions (lines changed)

modelzoo/bst/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -612,10 +612,9 @@ def train(sess_config,
612 612       hooks = []
613 613       hooks.extend(input_hooks)
614 614
615     -     sharded_saver = tf_config != None
616 615       scaffold = tf.train.Scaffold(
617 616           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
618     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    617 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
619 618
620 619       stop_hook = tf.train.StopAtStepHook(last_step=steps)
621 620       log_hook = tf.train.LoggingTensorHook(

modelzoo/dbmtl/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -527,10 +527,9 @@ def train(sess_config,
527 527       hooks = []
528 528       hooks.extend(input_hooks)
529 529
530     -     sharded_saver = tf_config != None
531 530       scaffold = tf.train.Scaffold(
532 531           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
533     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    532 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
534 533
535 534       stop_hook = tf.train.StopAtStepHook(last_step=steps)
536 535       log_hook = tf.train.LoggingTensorHook(

modelzoo/dcn/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594 594       hooks = []
595 595       hooks.extend(input_hooks)
596 596
597     -     sharded_saver = tf_config != None
598 597       scaffold = tf.train.Scaffold(
599 598           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    599 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601 600
602 601       stop_hook = tf.train.StopAtStepHook(last_step=steps)
603 602       log_hook = tf.train.LoggingTensorHook(

modelzoo/dcnv2/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -610,10 +610,9 @@ def train(sess_config,
610 610       hooks = []
611 611       hooks.extend(input_hooks)
612 612
613     -     sharded_saver = tf_config != None
614 613       scaffold = tf.train.Scaffold(
615 614           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
616     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    615 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
617 616
618 617       stop_hook = tf.train.StopAtStepHook(last_step=steps)
619 618       log_hook = tf.train.LoggingTensorHook(

modelzoo/deepfm/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -472,10 +472,9 @@ def train(sess_config,
472 472       hooks = []
473 473       hooks.extend(input_hooks)
474 474
475     -     sharded_saver = tf_config != None
476 475       scaffold = tf.train.Scaffold(
477 476           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
478     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    477 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
479 478
480 479       stop_hook = tf.train.StopAtStepHook(last_step=steps)
481 480       log_hook = tf.train.LoggingTensorHook(

modelzoo/dien/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -776,10 +776,9 @@ def train(sess_config,
776 776       hooks = []
777 777       hooks.extend(input_hooks)
778 778
779     -     sharded_saver = tf_config != None
780 779       scaffold = tf.train.Scaffold(
781 780           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
782     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    781 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
783 782
784 783       stop_hook = tf.train.StopAtStepHook(last_step=steps)
785 784       log_hook = tf.train.LoggingTensorHook(

modelzoo/din/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594 594       hooks = []
595 595       hooks.extend(input_hooks)
596 596
597     -     sharded_saver = tf_config != None
598 597       scaffold = tf.train.Scaffold(
599 598           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    599 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601 600
602 601       stop_hook = tf.train.StopAtStepHook(last_step=steps)
603 602       log_hook = tf.train.LoggingTensorHook(

modelzoo/dlrm/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -507,10 +507,9 @@ def train(sess_config,
507 507       hooks = []
508 508       hooks.extend(input_hooks)
509 509
510     -     sharded_saver = tf_config != None
511 510       scaffold = tf.train.Scaffold(
512 511           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
513     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    512 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
514 513
515 514       stop_hook = tf.train.StopAtStepHook(last_step=steps)
516 515       log_hook = tf.train.LoggingTensorHook(

modelzoo/dssm/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -478,10 +478,9 @@ def train(sess_config,
478 478       hooks = []
479 479       hooks.extend(input_hooks)
480 480
481     -     sharded_saver = tf_config != None
482 481       scaffold = tf.train.Scaffold(
483 482           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
484     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    483 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
485 484
486 485       stop_hook = tf.train.StopAtStepHook(last_step=steps)
487 486       log_hook = tf.train.LoggingTensorHook(

modelzoo/esmm/train.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -534,10 +534,9 @@ def train(sess_config,
534 534       hooks = []
535 535       hooks.extend(input_hooks)
536 536
537     -     sharded_saver = tf_config != None
538 537       scaffold = tf.train.Scaffold(
539 538           local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
540     -         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
    539 +         saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
541 540
542 541       stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
543 542       log_hook = tf.train.LoggingTensorHook(

0 commit comments

Comments
 (0)