@@ -129,22 +129,20 @@ def create_file_and_save_alive_counts(self, base_dir: Text,
129
129
global_step : int ) -> None :
130
130
"""Creates and updates files with alive counts.
131
131
132
- Creates the directory `{base_dir}/learned_structure/` and saves the current
133
- alive counts to:
134
- `{base_dir}/learned_structure/{ALIVE_FILENAME}_{global_step}`.
132
+ Creates the directory `{base_dir}` and saves the current alive counts to:
133
+ `{base_dir}/{ALIVE_FILENAME}_{global_step}`.
135
134
136
135
Args:
137
136
base_dir: where to export the alive counts.
138
137
global_step: current value of global step, used as a suffix in filename.
139
138
"""
140
139
current_filename = '%s_%s' % (ALIVE_FILENAME , global_step )
141
- directory = os .path .join (base_dir , 'learned_structure' )
142
140
try :
143
- tf .gfile .MakeDirs (directory )
141
+ tf .gfile .MakeDirs (base_dir )
144
142
except tf .errors .OpError :
145
143
# Probably already exists. If not, we'll see the error in the next line.
146
144
pass
147
- with tf .gfile .Open (os .path .join (directory , current_filename ), 'w' ) as f :
145
+ with tf .gfile .Open (os .path .join (base_dir , current_filename ), 'w' ) as f :
148
146
self .save_alive_counts (f ) # pytype: disable=wrong-arg-types
149
147
150
148
@@ -196,3 +194,29 @@ def _compute_alive_counts(
196
194
197
195
def format_structure (structure : Dict [Text , int ]) -> Text :
198
196
return json .dumps (structure , indent = 2 , sort_keys = True , default = str )
197
+
198
+
199
+ class StructureExporterHook (tf .train .SessionRunHook ):
200
+ """Estimator hook for StructureExporter.
201
+
202
+ Usage:
203
+ exporter = structure_exporter.StructureExporter(
204
+ network_regularizer.op_regularizer_manager)
205
+ structure_export_hook = structure_exporter.StructureExporterHook(
206
+ exporter, '/path/to/cns')
207
+ estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
208
+ ...,
209
+ training_hooks=[structure_export_hook])
210
+ """
211
+
212
+ def __init__ (self , exporter : StructureExporter , export_dir : Text ):
213
+ self ._export_dir = export_dir
214
+ self ._exporter = exporter
215
+
216
+ def end (self , session : tf .Session ):
217
+ global_step = session .run (tf .train .get_global_step ())
218
+ tf .logging .info ('Exporting structure at step %d' , global_step )
219
+ tensor_to_eval_dict = session .run (self ._exporter .tensors )
220
+ self ._exporter .populate_tensor_values (session .run (tensor_to_eval_dict ))
221
+ self ._exporter .create_file_and_save_alive_counts (self ._export_dir ,
222
+ global_step )
0 commit comments