Skip to content

Commit afdf80e

Browse files
committed
[Embedding] Check the sharded property of tf.train.Saver.
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 93c69ad commit afdf80e

File tree

18 files changed

+33
-32
lines changed

18 files changed

+33
-32
lines changed

modelzoo/bst/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,9 @@ def train(sess_config,
612612
hooks = []
613613
hooks.extend(input_hooks)
614614

615-
sharded_saver = tf_config != None
616615
scaffold = tf.train.Scaffold(
617616
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
618-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
617+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
619618

620619
stop_hook = tf.train.StopAtStepHook(last_step=steps)
621620
log_hook = tf.train.LoggingTensorHook(

modelzoo/dbmtl/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,9 @@ def train(sess_config,
527527
hooks = []
528528
hooks.extend(input_hooks)
529529

530-
sharded_saver = tf_config != None
531530
scaffold = tf.train.Scaffold(
532531
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
533-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
532+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
534533

535534
stop_hook = tf.train.StopAtStepHook(last_step=steps)
536535
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcn/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597-
sharded_saver = tf_config != None
598597
scaffold = tf.train.Scaffold(
599598
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
599+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601600

602601
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603602
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcnv2/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,9 @@ def train(sess_config,
610610
hooks = []
611611
hooks.extend(input_hooks)
612612

613-
sharded_saver = tf_config != None
614613
scaffold = tf.train.Scaffold(
615614
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
616-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
615+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
617616

618617
stop_hook = tf.train.StopAtStepHook(last_step=steps)
619618
log_hook = tf.train.LoggingTensorHook(

modelzoo/deepfm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,9 @@ def train(sess_config,
472472
hooks = []
473473
hooks.extend(input_hooks)
474474

475-
sharded_saver = tf_config != None
476475
scaffold = tf.train.Scaffold(
477476
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
478-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
477+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
479478

480479
stop_hook = tf.train.StopAtStepHook(last_step=steps)
481480
log_hook = tf.train.LoggingTensorHook(

modelzoo/dien/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -776,10 +776,9 @@ def train(sess_config,
776776
hooks = []
777777
hooks.extend(input_hooks)
778778

779-
sharded_saver = tf_config != None
780779
scaffold = tf.train.Scaffold(
781780
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
782-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
781+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
783782

784783
stop_hook = tf.train.StopAtStepHook(last_step=steps)
785784
log_hook = tf.train.LoggingTensorHook(

modelzoo/din/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,9 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597-
sharded_saver = tf_config != None
598597
scaffold = tf.train.Scaffold(
599598
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
599+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
601600

602601
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603602
log_hook = tf.train.LoggingTensorHook(

modelzoo/dlrm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,9 @@ def train(sess_config,
507507
hooks = []
508508
hooks.extend(input_hooks)
509509

510-
sharded_saver = tf_config != None
511510
scaffold = tf.train.Scaffold(
512511
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
513-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
512+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
514513

515514
stop_hook = tf.train.StopAtStepHook(last_step=steps)
516515
log_hook = tf.train.LoggingTensorHook(

modelzoo/dssm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,9 @@ def train(sess_config,
478478
hooks = []
479479
hooks.extend(input_hooks)
480480

481-
sharded_saver = tf_config != None
482481
scaffold = tf.train.Scaffold(
483482
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
484-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
483+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
485484

486485
stop_hook = tf.train.StopAtStepHook(last_step=steps)
487486
log_hook = tf.train.LoggingTensorHook(

modelzoo/esmm/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -534,10 +534,9 @@ def train(sess_config,
534534
hooks = []
535535
hooks.extend(input_hooks)
536536

537-
sharded_saver = tf_config != None
538537
scaffold = tf.train.Scaffold(
539538
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
540-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
539+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
541540

542541
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
543542
log_hook = tf.train.LoggingTensorHook(

modelzoo/masknet/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,9 @@ def train(sess_config,
529529
hooks = []
530530
hooks.extend(input_hooks)
531531

532-
sharded_saver = tf_config != None
533532
scaffold = tf.train.Scaffold(
534533
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
535-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
534+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
536535

537536
stop_hook = tf.train.StopAtStepHook(last_step=steps)
538537
log_hook = tf.train.LoggingTensorHook(

modelzoo/mlperf/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,9 @@ def train(sess_config,
522522
hooks = []
523523
hooks.extend(input_hooks)
524524

525-
sharded_saver = tf_config != None
526525
scaffold = tf.train.Scaffold(
527526
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
528-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
527+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
529528

530529
stop_hook = tf.train.StopAtStepHook(last_step=steps)
531530
log_hook = tf.train.LoggingTensorHook(

modelzoo/mmoe/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,9 @@ def train(sess_config,
523523
hooks = []
524524
hooks.extend(input_hooks)
525525

526-
sharded_saver = tf_config != None
527526
scaffold = tf.train.Scaffold(
528527
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
529-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
528+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
530529

531530
stop_hook = tf.train.StopAtStepHook(last_step=steps)
532531
log_hook = tf.train.LoggingTensorHook(

modelzoo/ple/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,9 @@ def train(sess_config,
592592
hooks = []
593593
hooks.extend(input_hooks)
594594

595-
sharded_saver = tf_config != None
596595
scaffold = tf.train.Scaffold(
597596
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
598-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
597+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
599598

600599
stop_hook = tf.train.StopAtStepHook(last_step=steps)
601600
log_hook = tf.train.LoggingTensorHook(

modelzoo/simple_multitask/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,9 @@ def train(sess_config,
427427
hooks = []
428428
hooks.extend(input_hooks)
429429

430-
sharded_saver = tf_config != None
431430
scaffold = tf.train.Scaffold(
432431
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
433-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
432+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
434433

435434
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
436435
log_hook = tf.train.LoggingTensorHook(

modelzoo/wide_and_deep/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,9 @@ def train(sess_config,
543543
hooks = []
544544
hooks.extend(input_hooks)
545545

546-
sharded_saver = tf_config != None
547546
scaffold = tf.train.Scaffold(
548547
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
549-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
548+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))
550549

551550
stop_hook = tf.train.StopAtStepHook(last_step=steps)
552551
log_hook = tf.train.LoggingTensorHook(

tensorflow/python/training/saver.py

+11
Original file line numberDiff line numberDiff line change
@@ -1071,10 +1071,14 @@ def _build(self, checkpoint_path, build_save, build_restore):
10711071
# pylint: disable=protected-access
10721072
self._var_list = variables._all_saveable_objects()
10731073
from tensorflow.python.ops import hash_table
1074+
from tensorflow.python.ops import kv_variable_ops
10741075
if isinstance(self._var_list, dict):
1076+
ev = {}
10751077
ht = {}
10761078
lst = {}
10771079
for name, x in self._var_list.items():
1080+
if isinstance(x, kv_variable_ops.EmbeddingVariable):
1081+
ev[name] = x
10781082
if isinstance(x, hash_table.HashTable):
10791083
if x.hash_table not in ht:
10801084
ht[x.hash_table] = [x]
@@ -1084,15 +1088,20 @@ def _build(self, checkpoint_path, build_save, build_restore):
10841088
lst[name] = BloomFilterSaveable(x)
10851089
else:
10861090
lst[name] = x
1091+
if len(ev) != 0 and not self._sharded:
1092+
raise ValueError("EmbeddingVariable can only use sharded saver")
10871093
if len(ht) != 0 and not self._sharded:
10881094
raise ValueError("HashTable can only use sharded saver")
10891095
for x, y in ht.items():
10901096
lst[x.name] = HashTableSaveable(y)
10911097
self._var_list = lst
10921098
else:
1099+
ev = []
10931100
ht = {}
10941101
lst = []
10951102
for x in self._var_list:
1103+
if isinstance(x, kv_variable_ops.EmbeddingVariable):
1104+
ev.append(x)
10961105
if isinstance(x, hash_table.HashTable):
10971106
if x.hash_table not in ht:
10981107
ht[x.hash_table] = [x]
@@ -1102,6 +1111,8 @@ def _build(self, checkpoint_path, build_save, build_restore):
11021111
lst.append(BloomFilterSaveable(x))
11031112
else:
11041113
lst.append(x)
1114+
if len(ev) != 0 and not self._sharded:
1115+
raise ValueError("EmbeddingVariable can only use sharded saver")
11051116
if len(ht) != 0 and not self._sharded:
11061117
raise ValueError("HashTable can only use sharded saver")
11071118
for x, y in ht.items():

tensorflow/python/training/saver_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,12 @@ def _model():
852852
for orig, restored in zip(orig_vals, restored_vals):
853853
self.assertAllEqual(orig, restored)
854854

855+
def testEnableSaverShardedWhenUseEmbeddingVariable(self):
856+
with ops_lib.Graph().as_default():
857+
emb_var = \
858+
variable_scope.get_embedding_variable(name="emb_var", embedding_dim=64)
859+
with self.assertRaisesRegexp(ValueError, "EmbeddingVariable"):
860+
saver_module.Saver([emb_var], sharded=False)
855861

856862
class SaveRestoreShardedTest(test.TestCase):
857863

0 commit comments

Comments
 (0)