Skip to content

Commit 5da3e0e

Browse files
committed
[Embedding] Check the sharded property of tf.train.Saver.
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 93c69ad commit 5da3e0e

File tree

22 files changed

+74
-70
lines changed

22 files changed

+74
-70
lines changed

modelzoo/bst/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,9 @@ def train(sess_config,
612612
hooks = []
613613
hooks.extend(input_hooks)
614614

615-
sharded_saver = tf_config != None
616615
scaffold = tf.train.Scaffold(
617616
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
618-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
617+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
619618

620619
stop_hook = tf.train.StopAtStepHook(last_step=steps)
621620
log_hook = tf.train.LoggingTensorHook(

modelzoo/dbmtl/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,9 @@ def train(sess_config,
527527
hooks = []
528528
hooks.extend(input_hooks)
529529

530-
sharded_saver = tf_config != None
531530
scaffold = tf.train.Scaffold(
532531
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
533-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
532+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
534533

535534
stop_hook = tf.train.StopAtStepHook(last_step=steps)
536535
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcn/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597-
sharded_saver = tf_config != None
598597
scaffold = tf.train.Scaffold(
599598
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
599+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601600

602601
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603602
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcnv2/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,9 @@ def train(sess_config,
610610
hooks = []
611611
hooks.extend(input_hooks)
612612

613-
sharded_saver = tf_config != None
614613
scaffold = tf.train.Scaffold(
615614
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
616-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
615+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
617616

618617
stop_hook = tf.train.StopAtStepHook(last_step=steps)
619618
log_hook = tf.train.LoggingTensorHook(

modelzoo/deepfm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,9 @@ def train(sess_config,
472472
hooks = []
473473
hooks.extend(input_hooks)
474474

475-
sharded_saver = tf_config != None
476475
scaffold = tf.train.Scaffold(
477476
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
478-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
477+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
479478

480479
stop_hook = tf.train.StopAtStepHook(last_step=steps)
481480
log_hook = tf.train.LoggingTensorHook(

modelzoo/dien/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -776,10 +776,9 @@ def train(sess_config,
776776
hooks = []
777777
hooks.extend(input_hooks)
778778

779-
sharded_saver = tf_config != None
780779
scaffold = tf.train.Scaffold(
781780
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
782-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
781+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
783782

784783
stop_hook = tf.train.StopAtStepHook(last_step=steps)
785784
log_hook = tf.train.LoggingTensorHook(

modelzoo/din/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597-
sharded_saver = tf_config != None
598597
scaffold = tf.train.Scaffold(
599598
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
599+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601600

602601
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603602
log_hook = tf.train.LoggingTensorHook(

modelzoo/dlrm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,9 @@ def train(sess_config,
507507
hooks = []
508508
hooks.extend(input_hooks)
509509

510-
sharded_saver = tf_config != None
511510
scaffold = tf.train.Scaffold(
512511
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
513-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
512+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
514513

515514
stop_hook = tf.train.StopAtStepHook(last_step=steps)
516515
log_hook = tf.train.LoggingTensorHook(

modelzoo/dssm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,9 @@ def train(sess_config,
478478
hooks = []
479479
hooks.extend(input_hooks)
480480

481-
sharded_saver = tf_config != None
482481
scaffold = tf.train.Scaffold(
483482
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
484-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
483+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
485484

486485
stop_hook = tf.train.StopAtStepHook(last_step=steps)
487486
log_hook = tf.train.LoggingTensorHook(

modelzoo/esmm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -534,10 +534,9 @@ def train(sess_config,
534534
hooks = []
535535
hooks.extend(input_hooks)
536536

537-
sharded_saver = tf_config != None
538537
scaffold = tf.train.Scaffold(
539538
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
540-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
539+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
541540

542541
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
543542
log_hook = tf.train.LoggingTensorHook(

modelzoo/masknet/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,9 @@ def train(sess_config,
529529
hooks = []
530530
hooks.extend(input_hooks)
531531

532-
sharded_saver = tf_config != None
533532
scaffold = tf.train.Scaffold(
534533
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
535-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
534+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
536535

537536
stop_hook = tf.train.StopAtStepHook(last_step=steps)
538537
log_hook = tf.train.LoggingTensorHook(

modelzoo/mlperf/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,9 @@ def train(sess_config,
522522
hooks = []
523523
hooks.extend(input_hooks)
524524

525-
sharded_saver = tf_config != None
526525
scaffold = tf.train.Scaffold(
527526
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
528-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
527+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
529528

530529
stop_hook = tf.train.StopAtStepHook(last_step=steps)
531530
log_hook = tf.train.LoggingTensorHook(

modelzoo/mmoe/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,9 @@ def train(sess_config,
523523
hooks = []
524524
hooks.extend(input_hooks)
525525

526-
sharded_saver = tf_config != None
527526
scaffold = tf.train.Scaffold(
528527
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
529-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
528+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
530529

531530
stop_hook = tf.train.StopAtStepHook(last_step=steps)
532531
log_hook = tf.train.LoggingTensorHook(

modelzoo/ple/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,9 @@ def train(sess_config,
592592
hooks = []
593593
hooks.extend(input_hooks)
594594

595-
sharded_saver = tf_config != None
596595
scaffold = tf.train.Scaffold(
597596
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
598-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
597+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
599598

600599
stop_hook = tf.train.StopAtStepHook(last_step=steps)
601600
log_hook = tf.train.LoggingTensorHook(

modelzoo/simple_multitask/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,9 @@ def train(sess_config,
427427
hooks = []
428428
hooks.extend(input_hooks)
429429

430-
sharded_saver = tf_config != None
431430
scaffold = tf.train.Scaffold(
432431
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
433-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
432+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
434433

435434
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
436435
log_hook = tf.train.LoggingTensorHook(

modelzoo/wide_and_deep/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,9 @@ def train(sess_config,
543543
hooks = []
544544
hooks.extend(input_hooks)
545545

546-
sharded_saver = tf_config != None
547546
scaffold = tf.train.Scaffold(
548547
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
549-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
548+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
550549

551550
stop_hook = tf.train.StopAtStepHook(last_step=steps)
552551
log_hook = tf.train.LoggingTensorHook(

tensorflow/python/feature_column/feature_column_v2_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
75277527
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
75287528
g_v = opt.compute_gradients(loss)
75297529
train_op = opt.apply_gradients(g_v)
7530-
saver = saver_module.Saver()
7530+
saver = saver_module.Saver(sharded=True)
75317531
init = variables_lib.global_variables_initializer()
75327532
with self.test_session() as sess:
75337533
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
@@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
77587758
g_v = opt.compute_gradients(loss)
77597759
train_op = opt.apply_gradients(g_v)
77607760
init = variables_lib.global_variables_initializer()
7761-
saver = saver_module.Saver()
7761+
saver = saver_module.Saver(sharded=True)
77627762

77637763
@test_util.run_deprecated_v1
77647764
def testEmbeddingVariableForInt32ID(self):
@@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self):
77837783
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
77847784
g_v = opt.compute_gradients(loss)
77857785
train_op = opt.apply_gradients(g_v)
7786-
saver = saver_module.Saver()
7786+
saver = saver_module.Saver(sharded=True)
77877787
init = variables_lib.global_variables_initializer()
77887788
with self.test_session() as sess:
77897789
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))

tensorflow/python/ops/embedding_variable_ops_gpu_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def testSaveV3(self):
748748
g_v = opt.compute_gradients(loss)
749749
train_op = opt.apply_gradients(g_v, global_step=gs)
750750
init = variables.global_variables_initializer()
751-
saver = saver = saver_module.Saver()
751+
saver = saver = saver_module.Saver(sharded=True)
752752
checkpoint_directory = self.get_temp_dir()
753753
model_path = os.path.join(checkpoint_directory, "model.ckpt")
754754
with self.test_session() as sess:
@@ -816,7 +816,7 @@ def testEmbeddingVariableSaveAndRestoreOptimzierStatesForMultiTierWithHbm(self):
816816
opt = adagrad.AdagradOptimizer(0.1)
817817
g_v = opt.compute_gradients(loss)
818818
train_op = opt.apply_gradients(g_v, gs)
819-
saver = saver_module.Saver()
819+
saver = saver_module.Saver(sharded=True)
820820
graph = ops.get_default_graph()
821821
with self.test_session(graph = graph) as sess:
822822
saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt-12345"))

0 commit comments

Comments
 (0)