Skip to content

Commit 8e00552

Browse files
committed
Added fail_if_exists as an argument for write_to_variables.
PiperOrigin-RevId: 277081924
1 parent 7c7b409 commit 8e00552

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

Diff for: morph_net/framework/tpu_util.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,21 @@ def maybe_convert_to_variable(tensor):
9494
var_store = {}
9595

9696

97-
def write_to_variable(tensor):
97+
def write_to_variable(tensor, fail_if_exists=True):
9898
"""Saves a tensor for later retrieval on CPU."""
9999
# Only relevant for debugging.
100100
debug_name = 'tpu_util__' + tensor.name.split(':')[0].split('/')[-1]
101101

102-
# Note: reuse cannot be changed from True to False, so we just check if
103-
# the variable exists.
104-
with tf.variable_scope('', reuse=True):
105-
try:
106-
tf.get_variable(debug_name)
107-
except ValueError:
108-
pass # Variable with name=debug_name does not exist; proceed.
109-
else:
110-
raise ValueError('Variable %s already exists!' % debug_name)
102+
if fail_if_exists:
103+
# Note: reuse cannot be changed from True to False, so we just check if
104+
# the variable exists.
105+
with tf.variable_scope('', reuse=True):
106+
try:
107+
tf.get_variable(debug_name)
108+
except ValueError:
109+
pass # Variable with name=debug_name does not exist; proceed.
110+
else:
111+
raise ValueError('Variable %s already exists!' % debug_name)
111112

112113
with tf.variable_scope('', reuse=tf.compat.v1.AUTO_REUSE):
113114
variable = tf.get_variable(

0 commit comments

Comments
 (0)