Skip to content

Commit 35282f6

Browse files
pkchmn-robot
authored andcommitted
Create Estimator hook for StructureExporter.
PiperOrigin-RevId: 277371977
1 parent 8e00552 commit 35282f6

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

Diff for: README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ with tf.Session() as sess:
199199
_, structure_exporter_tensors = sess.run([train_op, exporter.tensors])
200200
if (step % 1000 == 0):
201201
exporter.populate_tensor_values(structure_exporter_tensors)
202-
exporter.create_file_and_save_alive_counts(train_dir, step)
202+
exporter.create_file_and_save_alive_counts(
203+
os.path.join(train_dir, 'learned_structure'), step)
203204
```
204205

205206
## Misc

Diff for: morph_net/tools/structure_exporter.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,20 @@ def create_file_and_save_alive_counts(self, base_dir: Text,
129129
global_step: int) -> None:
130130
"""Creates and updates files with alive counts.
131131
132-
Creates the directory `{base_dir}/learned_structure/` and saves the current
133-
alive counts to:
134-
`{base_dir}/learned_structure/{ALIVE_FILENAME}_{global_step}`.
132+
Creates the directory `{base_dir}` and saves the current alive counts to:
133+
`{base_dir}/{ALIVE_FILENAME}_{global_step}`.
135134
136135
Args:
137136
base_dir: where to export the alive counts.
138137
global_step: current value of global step, used as a suffix in filename.
139138
"""
140139
current_filename = '%s_%s' % (ALIVE_FILENAME, global_step)
141-
directory = os.path.join(base_dir, 'learned_structure')
142140
try:
143-
tf.gfile.MakeDirs(directory)
141+
tf.gfile.MakeDirs(base_dir)
144142
except tf.errors.OpError:
145143
# Probably already exists. If not, we'll see the error in the next line.
146144
pass
147-
with tf.gfile.Open(os.path.join(directory, current_filename), 'w') as f:
145+
with tf.gfile.Open(os.path.join(base_dir, current_filename), 'w') as f:
148146
self.save_alive_counts(f) # pytype: disable=wrong-arg-types
149147

150148

@@ -196,3 +194,29 @@ def _compute_alive_counts(
196194

197195
def format_structure(structure: Dict[Text, int]) -> Text:
198196
return json.dumps(structure, indent=2, sort_keys=True, default=str)
197+
198+
199+
class StructureExporterHook(tf.train.SessionRunHook):
200+
"""Estimator hook for StructureExporter.
201+
202+
Usage:
203+
exporter = structure_exporter.StructureExporter(
204+
network_regularizer.op_regularizer_manager)
205+
structure_export_hook = structure_exporter.StructureExporterHook(
206+
exporter, '/path/to/cns')
207+
estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
208+
...,
209+
training_hooks=[structure_export_hook])
210+
"""
211+
212+
def __init__(self, exporter: StructureExporter, export_dir: Text):
213+
self._export_dir = export_dir
214+
self._exporter = exporter
215+
216+
def end(self, session: tf.Session):
217+
global_step = session.run(tf.train.get_global_step())
218+
tf.logging.info('Exporting structure at step %d', global_step)
219+
tensor_to_eval_dict = session.run(self._exporter.tensors)
220+
self._exporter.populate_tensor_values(session.run(tensor_to_eval_dict))
221+
self._exporter.create_file_and_save_alive_counts(self._export_dir,
222+
global_step)

Diff for: morph_net/tools/structure_exporter_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ def test_create_file_and_save_alive_counts(self):
129129
base_dir = os.path.join(FLAGS.test_tmpdir, 'ee')
130130

131131
self.exporter.populate_tensor_values(self.tensor_value_1)
132-
self.exporter.create_file_and_save_alive_counts(base_dir, 19)
132+
self.exporter.create_file_and_save_alive_counts(
133+
os.path.join(base_dir, 'learned_structure'), 19)
133134
self.assertAllEqual(
134135
_alive_from_file('ee/learned_structure/alive_19'),
135136
self.expected_alive_1)
136137

137138
self.exporter.populate_tensor_values(self.tensor_value_2)
138-
self.exporter.create_file_and_save_alive_counts(base_dir, 1009)
139+
self.exporter.create_file_and_save_alive_counts(
140+
os.path.join(base_dir, 'learned_structure'), 1009)
139141
self.assertAllEqual(
140142
_alive_from_file('ee/learned_structure/alive_1009'),
141143
self.expected_alive_2)

0 commit comments

Comments
 (0)