Skip to content

Commit 04413cf

Browse files
authored
[Embedding] Fix the issue of default_value type mismatch in the EV Gather op. (#989)
Signed-off-by: Lyaction <[email protected]>
1 parent a4489e3 commit 04413cf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorflow/python/ops/kv_variable_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -858,10 +858,10 @@ def sparse_read(self, indices, name=None, ev_init_value=None, counts=None):
858858
if self._trainable:
859859
tape.variable_accessed(self)
860860
if ev_init_value is not None:
861-
default_value = ev_init_value
861+
default_value = math_ops.cast(ev_init_value, self.dtype)
862862
is_use_default_value_tensor = True
863863
else:
864-
default_value = ops.convert_to_tensor(1.0)
864+
default_value = ops.convert_to_tensor(1.0, dtype=self.dtype)
865865
is_use_default_value_tensor = False
866866
if counts != None:
867867
value = gen_kv_variable_ops.kv_resource_gather_v1(self._handle,

0 commit comments

Comments
 (0)