Skip to content

Commit 656e7a2

Browse files
author
Amit Patankar
authored
Merge pull request tensorflow#21425 from saeta/fix_tpu
Refactor dependencies so keras_support can be imported directly.
2 parents b7127e5 + 14b8b8b commit 656e7a2

File tree

6 files changed

+24
-10
lines changed

6 files changed

+24
-10
lines changed

tensorflow/contrib/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ py_library(
107107
"//tensorflow/contrib/tfprof",
108108
"//tensorflow/contrib/timeseries",
109109
"//tensorflow/contrib/tpu",
110-
"//tensorflow/contrib/tpu:tpu_py",
111110
"//tensorflow/contrib/training:training_py",
112111
"//tensorflow/contrib/util:util_py",
113112
"//tensorflow/python:util",

tensorflow/contrib/cmake/python_protos.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ tensorflow/core
22
tensorflow/core/kernels/boosted_trees
33
tensorflow/core/profiler
44
tensorflow/python
5+
tensorflow/compiler/xla
56
tensorflow/contrib/boosted_trees/proto
67
tensorflow/contrib/cloud/kernels
78
tensorflow/contrib/decision_trees/proto

tensorflow/contrib/distribute/python/BUILD

+1-2
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,7 @@ py_library(
272272
deps = [
273273
":one_device_strategy",
274274
":values",
275-
"//tensorflow/contrib/tpu",
276-
"//tensorflow/contrib/tpu:tpu_py",
275+
"//tensorflow/contrib/tpu:tpu_lib",
277276
"//tensorflow/python:constant_op",
278277
"//tensorflow/python:control_flow_ops",
279278
"//tensorflow/python:framework_ops",

tensorflow/contrib/tpu/BUILD

+12-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ py_library(
4646
srcs_version = "PY2AND3",
4747
deps = [
4848
":tpu_lib",
49-
":tpu_py",
49+
"//tensorflow/compiler/xla/experimental/xla_sharding",
50+
"//tensorflow/compiler/xla/python_api:xla_shape",
5051
"//tensorflow/contrib/training:training_py",
5152
"//tensorflow/core:protos_all_py",
5253
"//tensorflow/python:array_ops",
@@ -133,7 +134,7 @@ py_library(
133134

134135
tf_custom_op_py_library(
135136
name = "tpu_py",
136-
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
137+
srcs = glob(["python/ops/*.py"]),
137138
dso = [":python/ops/_tpu_ops.so"],
138139
kernels = [
139140
":all_ops",
@@ -152,9 +153,13 @@ tf_custom_op_py_library(
152153

153154
py_library(
154155
name = "tpu",
155-
srcs = ["python/tpu/__init__.py"],
156+
srcs = [
157+
"__init__.py",
158+
"python/tpu/__init__.py",
159+
],
156160
srcs_version = "PY2AND3",
157161
deps = [
162+
":keras_support", # split out to avoid cycle with tpu_strategy
158163
":tpu_estimator",
159164
":tpu_lib",
160165
],
@@ -166,11 +171,13 @@ py_library(
166171
"python/tpu/keras_support.py",
167172
],
168173
srcs_version = "PY2AND3",
174+
visibility = [
175+
"//tensorflow:__subpackages__",
176+
],
169177
deps = [
170178
":tpu_lib",
171-
":tpu_py",
172179
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
173-
"//tensorflow/contrib/distribute/python:tpu_strategy",
180+
"//tensorflow/contrib/distribute",
174181
"//tensorflow/contrib/framework:framework_py",
175182
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
176183
"//tensorflow/core:protos_all_py",

tensorflow/contrib/tpu/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
@@TPUConfig
4848
4949
@@bfloat16_scope
50+
51+
@@TPUDistributionStrategy
52+
@@keras_to_tpu_model
5053
"""
5154

5255
from __future__ import absolute_import
@@ -58,6 +61,8 @@
5861
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
5962
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
6063
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
64+
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
65+
from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy
6166
from tensorflow.contrib.tpu.python.tpu.topology import *
6267
from tensorflow.contrib.tpu.python.tpu.tpu import *
6368
from tensorflow.contrib.tpu.python.tpu.tpu_config import *

tensorflow/contrib/tpu/python/tpu/keras_support.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
import numpy as np
5656

5757
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
58-
from tensorflow.contrib.distribute.python import tpu_strategy
5958
from tensorflow.contrib.framework.python.framework import experimental
6059
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
6160
from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -82,7 +81,11 @@
8281
from tensorflow.python.ops import variable_scope
8382
from tensorflow.python.platform import tf_logging as logging
8483

85-
TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name
84+
85+
# Work-around dependency cycle between DistributionStrategy and TPU lib.
86+
def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name
87+
from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
88+
return tpu_strategy.TPUStrategy(*args, **kw)
8689

8790

8891
class TPUEmbedding(embeddings.Embedding):

0 commit comments

Comments
 (0)