@@ -1071,10 +1071,14 @@ def _build(self, checkpoint_path, build_save, build_restore):
1071
1071
# pylint: disable=protected-access
1072
1072
self ._var_list = variables ._all_saveable_objects ()
1073
1073
from tensorflow .python .ops import hash_table
1074
+ from tensorflow .python .ops import kv_variable_ops
1074
1075
if isinstance (self ._var_list , dict ):
1076
+ ev = {}
1075
1077
ht = {}
1076
1078
lst = {}
1077
1079
for name , x in self ._var_list .items ():
1080
+ if isinstance (x , kv_variable_ops .EmbeddingVariable ):
1081
+ ev [name ] = x
1078
1082
if isinstance (x , hash_table .HashTable ):
1079
1083
if x .hash_table not in ht :
1080
1084
ht [x .hash_table ] = [x ]
@@ -1084,15 +1088,20 @@ def _build(self, checkpoint_path, build_save, build_restore):
1084
1088
lst [name ] = BloomFilterSaveable (x )
1085
1089
else :
1086
1090
lst [name ] = x
1091
+ if len (ev ) != 0 and not self ._sharded :
1092
+ raise ValueError ("EmbeddingVariable can only use sharded saver" )
1087
1093
if len (ht ) != 0 and not self ._sharded :
1088
1094
raise ValueError ("HashTable can only use sharded saver" )
1089
1095
for x , y in ht .items ():
1090
1096
lst [x .name ] = HashTableSaveable (y )
1091
1097
self ._var_list = lst
1092
1098
else :
1099
+ ev = []
1093
1100
ht = {}
1094
1101
lst = []
1095
1102
for x in self ._var_list :
1103
+ if isinstance (x , kv_variable_ops .EmbeddingVariable ):
1104
+ ev .append (x )
1096
1105
if isinstance (x , hash_table .HashTable ):
1097
1106
if x .hash_table not in ht :
1098
1107
ht [x .hash_table ] = [x ]
@@ -1102,6 +1111,8 @@ def _build(self, checkpoint_path, build_save, build_restore):
1102
1111
lst .append (BloomFilterSaveable (x ))
1103
1112
else :
1104
1113
lst .append (x )
1114
+ if len (ev ) != 0 and not self ._sharded :
1115
+ raise ValueError ("EmbeddingVariable can only use sharded saver" )
1105
1116
if len (ht ) != 0 and not self ._sharded :
1106
1117
raise ValueError ("HashTable can only use sharded saver" )
1107
1118
for x , y in ht .items ():
0 commit comments