@@ -106,12 +106,35 @@ def end(self, session):
106
106
107
107
108
108
class TPUEstimator (Estimator , tf .contrib .tpu .TPUEstimator ):
109
- """An adanet.Estimator capable of running on TPU.
110
-
111
- If running on TPU, all summary calls are rewired to be no-ops during training.
112
-
113
- WARNING: this API is highly experimental, unstable, and can change without
114
- warning.
109
+ """An :class:`adanet.Estimator` capable of training and evaluating on TPU.
110
+
111
+ Note: Unless :code:`use_tpu=False`, training will run on TPU. However,
112
+ certain parts of the AdaNet training loop, such as report materialization and best
113
+ candidate selection still occur on CPU. Furthermore, inference also occurs on
114
+ CPU.
115
+
116
+ Args:
117
+ head: See :class:`adanet.Estimator`.
118
+ subnetwork_generator: See :class:`adanet.Estimator`.
119
+ max_iteration_steps: See :class:`adanet.Estimator`.
120
+ ensemblers: See :class:`adanet.Estimator`.
121
+ ensemble_strategies: See :class:`adanet.Estimator`.
122
+ evaluator: See :class:`adanet.Estimator`.
123
+ report_materializer: See :class:`adanet.Estimator`.
124
+ metric_fn: See :class:`adanet.Estimator`.
125
+ force_grow: See :class:`adanet.Estimator`.
126
+ replicate_ensemble_in_training: See :class:`adanet.Estimator`.
127
+ adanet_loss_decay: See :class:`adanet.Estimator`.
128
+ report_dir: See :class:`adanet.Estimator`.
129
+ config: See :class:`adanet.Estimator`.
130
+ use_tpu: Boolean to enable *both* training and evaluating on TPU. Defaults
131
+ to :code:`True` and is only provided to allow debugging models on
132
+ CPU/GPU. Use :class:`adanet.Estimator` instead if you do not plan to run
133
+ on TPU.
134
+ train_batch_size: See :class:`tf.contrib.tpu.TPUEstimator`.
135
+ eval_batch_size: See :class:`tf.contrib.tpu.TPUEstimator`.
136
+ debug: See :class:`adanet.Estimator`.
137
+ **kwargs: Extra keyword args passed to the parent.
115
138
"""
116
139
117
140
def __init__ (self ,
@@ -126,13 +149,13 @@ def __init__(self,
126
149
force_grow = False ,
127
150
replicate_ensemble_in_training = False ,
128
151
adanet_loss_decay = .9 ,
129
- worker_wait_timeout_secs = 7200 ,
130
152
model_dir = None ,
131
153
report_dir = None ,
132
154
config = None ,
133
155
use_tpu = True ,
134
156
train_batch_size = None ,
135
157
eval_batch_size = None ,
158
+ debug = False ,
136
159
** kwargs ):
137
160
self ._use_tpu = use_tpu
138
161
if not self ._use_tpu :
@@ -157,7 +180,6 @@ def __init__(self,
157
180
force_grow = force_grow ,
158
181
replicate_ensemble_in_training = replicate_ensemble_in_training ,
159
182
adanet_loss_decay = adanet_loss_decay ,
160
- worker_wait_timeout_secs = worker_wait_timeout_secs ,
161
183
model_dir = model_dir ,
162
184
report_dir = report_dir ,
163
185
config = config if config else tf .contrib .tpu .RunConfig (),
@@ -168,12 +190,14 @@ def __init__(self,
168
190
eval_batch_size = self ._eval_batch_size ,
169
191
** kwargs )
170
192
193
+ # Yields predictions on CPU even when use_tpu=True.
171
194
def predict (self ,
172
195
input_fn ,
173
196
predict_keys = None ,
174
197
hooks = None ,
175
198
checkpoint_path = None ,
176
199
yield_single_examples = True ):
200
+
177
201
# TODO: Required to support predict on CPU for TPUEstimator.
178
202
# This is the recommended method from TensorFlow TPUEstimator docs:
179
203
# https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimator#current_limitations
0 commit comments