|
32 | 32 | from tensor2tensor.utils import mlperf_log
|
33 | 33 |
|
34 | 34 | import tensorflow as tf
|
35 |
| -import tf_slim as slim |
36 |
| -from tensorflow.contrib.tpu.python.tpu import tpu_config |
| 35 | +# pylint: disable=g-import-not-at-top |
| 36 | +try: |
| 37 | + from tensorflow.contrib.tpu.python.tpu import tpu_config |
| 38 | +except ImportError: |
| 39 | + # TF 2.0 doesn't ship with contrib. |
| 40 | + tpu_config = None |
| 41 | +# pylint: enable=g-import-not-at-top |
37 | 42 |
|
38 | 43 |
|
39 | 44 |
|
@@ -199,7 +204,7 @@ class Problem(object):
|
199 | 204 | - Mutate defaults as needed
|
200 | 205 | * example_reading_spec
|
201 | 206 | - Specify the names and types of the features on disk.
|
202 |
| - - Specify slim.tfexample_decoder |
| 207 | + - Specify tf.contrib.slim.tfexample_decoder |
203 | 208 | * preprocess_example(example, mode, hparams)
|
204 | 209 | - Preprocess the example feature dict from feature name to Tensor or
|
205 | 210 | SparseTensor.
|
@@ -643,7 +648,7 @@ def dataset(self,
|
643 | 648 |
|
644 | 649 | data_filepattern = self.filepattern(data_dir, dataset_split, shard=shard)
|
645 | 650 | tf.logging.info("Reading data files from %s", data_filepattern)
|
646 |
| - data_files = sorted(slim.parallel_reader.get_data_files( |
| 651 | + data_files = sorted(tf.contrib.slim.parallel_reader.get_data_files( |
647 | 652 | data_filepattern))
|
648 | 653 |
|
649 | 654 | # Functions used in dataset transforms below. `filenames` can be either a
|
@@ -706,12 +711,12 @@ def decode_example(self, serialized_example):
|
706 | 711 | data_fields["batch_prediction_key"] = tf.FixedLenFeature([1], tf.int64, 0)
|
707 | 712 | if data_items_to_decoders is None:
|
708 | 713 | data_items_to_decoders = {
|
709 |
| - field: slim.tfexample_decoder.Tensor(field) |
| 714 | + field: tf.contrib.slim.tfexample_decoder.Tensor(field) |
710 | 715 | for field in data_fields
|
711 | 716 | }
|
712 | 717 |
|
713 |
| - decoder = slim.tfexample_decoder.TFExampleDecoder(data_fields, |
714 |
| - data_items_to_decoders) |
| 718 | + decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder( |
| 719 | + data_fields, data_items_to_decoders) |
715 | 720 |
|
716 | 721 | decode_items = list(sorted(data_items_to_decoders))
|
717 | 722 | decoded = decoder.decode(serialized_example, items=decode_items)
|
|
0 commit comments