Commit b49b0be

Eugen Hotaj authored and cweill committed
Remove outdated doc about TPUEstimator not supporting summaries.
PiperOrigin-RevId: 240646161
1 parent ed25dc2 commit b49b0be

1 file changed: +32 -8 lines changed

Diff for: adanet/core/tpu_estimator.py

@@ -106,12 +106,35 @@ def end(self, session):
 
 
 class TPUEstimator(Estimator, tf.contrib.tpu.TPUEstimator):
-  """An adanet.Estimator capable of running on TPU.
-
-  If running on TPU, all summary calls are rewired to be no-ops during training.
-
-  WARNING: this API is highly experimental, unstable, and can change without
-  warning.
+  """An :class:`adanet.Estimator` capable of training and evaluating on TPU.
+
+  Note: Unless :code:`use_tpu=False`, training will run on TPU. However,
+  certain parts of the AdaNet training loop, such as report materialization
+  and best candidate selection, still occur on CPU. Furthermore, inference
+  also occurs on CPU.
+
+  Args:
+    head: See :class:`adanet.Estimator`.
+    subnetwork_generator: See :class:`adanet.Estimator`.
+    max_iteration_steps: See :class:`adanet.Estimator`.
+    ensemblers: See :class:`adanet.Estimator`.
+    ensemble_strategies: See :class:`adanet.Estimator`.
+    evaluator: See :class:`adanet.Estimator`.
+    report_materializer: See :class:`adanet.Estimator`.
+    metric_fn: See :class:`adanet.Estimator`.
+    force_grow: See :class:`adanet.Estimator`.
+    replicate_ensemble_in_training: See :class:`adanet.Estimator`.
+    adanet_loss_decay: See :class:`adanet.Estimator`.
+    report_dir: See :class:`adanet.Estimator`.
+    config: See :class:`adanet.Estimator`.
+    use_tpu: Boolean to enable *both* training and evaluating on TPU. Defaults
+      to :code:`True` and is only provided to allow debugging models on
+      CPU/GPU. Use :class:`adanet.Estimator` instead if you do not plan to run
+      on TPU.
+    train_batch_size: See :class:`tf.contrib.tpu.TPUEstimator`.
+    eval_batch_size: See :class:`tf.contrib.tpu.TPUEstimator`.
+    debug: See :class:`adanet.Estimator`.
+    **kwargs: Extra keyword args passed to the parent.
   """
 
   def __init__(self,
@@ -126,13 +149,13 @@ def __init__(self,
                force_grow=False,
                replicate_ensemble_in_training=False,
                adanet_loss_decay=.9,
-               worker_wait_timeout_secs=7200,
                model_dir=None,
                report_dir=None,
                config=None,
                use_tpu=True,
                train_batch_size=None,
                eval_batch_size=None,
+               debug=False,
                **kwargs):
     self._use_tpu = use_tpu
     if not self._use_tpu:
@@ -157,7 +180,6 @@ def __init__(self,
         force_grow=force_grow,
         replicate_ensemble_in_training=replicate_ensemble_in_training,
         adanet_loss_decay=adanet_loss_decay,
-        worker_wait_timeout_secs=worker_wait_timeout_secs,
         model_dir=model_dir,
         report_dir=report_dir,
         config=config if config else tf.contrib.tpu.RunConfig(),
@@ -168,12 +190,14 @@ def __init__(self,
         eval_batch_size=self._eval_batch_size,
         **kwargs)
 
+  # Yields predictions on CPU even when use_tpu=True.
   def predict(self,
               input_fn,
               predict_keys=None,
               hooks=None,
               checkpoint_path=None,
               yield_single_examples=True):
+
     # TODO: Required to support predict on CPU for TPUEstimator.
     # This is the recommended method from TensorFlow TPUEstimator docs:
     # https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimator#current_limitations
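For context, here is a minimal usage sketch of the constructor documented above, assuming the TensorFlow 1.x tf.contrib APIs referenced in this diff. The subnetwork generator and input functions are hypothetical placeholders, not part of this commit.

import adanet
import tensorflow as tf

# Hypothetical placeholders; any adanet-compatible head, subnetwork
# generator, and input_fn implementations would work here.
head = tf.contrib.estimator.regression_head()
subnetwork_generator = MySubnetworkGenerator()  # hypothetical, defined elsewhere

estimator = adanet.TPUEstimator(
    head=head,
    subnetwork_generator=subnetwork_generator,
    max_iteration_steps=1000,
    # use_tpu defaults to True; set it to False only to debug on CPU/GPU.
    use_tpu=True,
    train_batch_size=64,
    eval_batch_size=64,
    config=tf.contrib.tpu.RunConfig())

estimator.train(input_fn=train_input_fn, max_steps=5000)

# Per the comment added in this commit, predict() yields results on CPU
# even when use_tpu=True.
predictions = estimator.predict(input_fn=predict_input_fn)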
