File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed
Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -58,6 +58,8 @@ def maybe_convert_to_variable(tensor):
5858 return the original input tensor.
5959 """
6060 op = tensor .op
61+ if is_on_cpu () and tensor in var_store :
62+ return var_store [tensor ]
6163 if op .type != 'ReadVariableOp' :
6264 # No need to convert.
6365 return tensor
@@ -102,7 +104,9 @@ def write_to_variable(tensor):
102104 use_resource = True )
103105 var_store [tensor ] = variable
104106 with tf .control_dependencies ([variable .assign (tensor )]):
105- return tf .identity (tensor )
107+ tensor_copy = tf .identity (tensor )
108+ var_store [tensor_copy ] = variable
109+ return tensor_copy
106110
107111
108112def read_from_variable (tensor ):
@@ -113,3 +117,8 @@ def read_from_variable(tensor):
113117 else :
114118 # Current read, but only works on TPU.
115119 return tensor
120+
121+
122+ def is_intermediate_var (v ):
123+ """Returns True if `v` was created by `write_to_variable` above."""
124+ return v in var_store .values ()
You can’t perform that action at this time.
0 commit comments