Skip to content

Commit 27e3966

Browse files
authored
Cherrypick Keras DTensor related updates into keras 2.9 (#16379)
* Enable the keras dtensor API in OSS. PiperOrigin-RevId: 438858608 * Switching learning/brain dependency to OSS compatible test_util This is one test file failing, due to the monkey patching happens in the dtensor.init(), and I will need to dig more about the root cause (probably due to patching tf.Variable with DVariable, and cause logic difference for instance type checking.) PiperOrigin-RevId: 439676157
1 parent 55476a8 commit 27e3966

File tree

13 files changed

+199
-45
lines changed

13 files changed

+199
-45
lines changed

keras/dtensor/BUILD

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Since DTensor is not a public API yet, all the DTensor related change
33
# can't be exposed to public yet.
44

5+
load("@org_keras//keras:keras.bzl", "tf_py_test")
6+
57
package(
68
default_visibility = [
79
"//keras:friends",
@@ -15,34 +17,33 @@ py_library(
1517
srcs = ["__init__.py"],
1618
)
1719

18-
py_test(
20+
tf_py_test(
1921
name = "initializers_test",
2022
srcs = ["initializers_test.py"],
2123
shard_count = 4,
22-
tags = ["no_oss"],
2324
deps = [
2425
":dtensor",
26+
":test_util",
2527
"//:expect_numpy_installed",
2628
"//:expect_tensorflow_installed",
2729
"//keras:backend",
2830
"//keras/initializers",
2931
"//keras/utils:tf_utils",
30-
"//learning/brain/experimental/dtensor/tests:test_util",
3132
],
3233
)
3334

34-
py_test(
35+
tf_py_test(
3536
name = "layers_test",
3637
srcs = ["layers_test.py"],
3738
shard_count = 4,
3839
tags = ["no_oss"],
3940
deps = [
4041
":dtensor",
42+
":test_util",
4143
"//:expect_numpy_installed",
4244
"//:expect_tensorflow_installed",
4345
"//keras/layers",
4446
"//keras/utils:tf_utils",
45-
"//learning/brain/experimental/dtensor/tests:test_util",
4647
],
4748
)
4849

@@ -57,7 +58,7 @@ py_library(
5758
],
5859
)
5960

60-
py_test(
61+
tf_py_test(
6162
name = "layout_map_test",
6263
srcs = ["layout_map_test.py"],
6364
tags = ["no_oss"],
@@ -89,36 +90,34 @@ py_library(
8990
],
9091
)
9192

92-
py_test(
93+
tf_py_test(
9394
name = "metrics_test",
9495
srcs = ["metrics_test.py"],
9596
shard_count = 4,
96-
tags = ["no_oss"],
9797
deps = [
9898
":dtensor",
99+
":test_util",
99100
"//:expect_absl_installed",
100101
"//:expect_numpy_installed",
101102
"//:expect_tensorflow_installed",
102103
"//keras/metrics",
103104
"//keras/utils:tf_utils",
104-
"//learning/brain/experimental/dtensor/tests:test_util",
105105
],
106106
)
107107

108-
py_test(
108+
tf_py_test(
109109
name = "mnist_model_test",
110110
srcs = ["mnist_model_test.py"],
111111
tags = [
112-
"no_oss",
113112
"requires-net:external",
114113
],
115114
deps = [
116115
":integration_test_utils",
117116
":optimizers",
117+
":test_util",
118118
"//:expect_numpy_installed",
119119
"//:expect_tensorflow_installed",
120120
"//keras/utils:tf_utils",
121-
"//learning/brain/experimental/dtensor/tests:test_util",
122121
],
123122
)
124123

@@ -133,16 +132,15 @@ py_library(
133132
],
134133
)
135134

136-
py_test(
135+
tf_py_test(
137136
name = "optimizers_test",
138137
srcs = ["optimizers_test.py"],
139-
tags = ["no_oss"],
140138
deps = [
141139
":dtensor",
142140
":optimizers",
141+
":test_util",
143142
"//:expect_numpy_installed",
144143
"//:expect_tensorflow_installed",
145-
"//learning/brain/experimental/dtensor/tests:test_util",
146144
],
147145
)
148146

@@ -163,17 +161,26 @@ py_library(
163161
],
164162
)
165163

166-
py_test(
164+
tf_py_test(
167165
name = "utils_test",
168166
srcs = ["utils_test.py"],
169-
tags = ["no_oss"],
170167
deps = [
171168
":dtensor",
169+
":test_util",
172170
":utils",
173171
"//:expect_absl_installed",
174172
"//:expect_numpy_installed",
175173
"//:expect_tensorflow_installed",
176174
"//keras/layers",
177-
"//learning/brain/experimental/dtensor/tests:test_util",
175+
],
176+
)
177+
178+
py_library(
179+
name = "test_util",
180+
srcs = ["test_util.py"],
181+
deps = [
182+
"//:expect_absl_installed",
183+
"//:expect_numpy_installed",
184+
"//:expect_tensorflow_installed",
178185
],
179186
)

keras/dtensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
"""Keras' DTensor library."""
1616

17-
_DTENSOR_API_ENABLED = False
17+
_DTENSOR_API_ENABLED = True
1818

1919

2020
# Conditional import the dtensor API, since it is currently broken in OSS.

keras/dtensor/initializers_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from keras import backend
1919
from keras import initializers
2020
from keras.dtensor import dtensor_api as dtensor
21+
from keras.dtensor import test_util
2122
from keras.utils import tf_utils
2223
import numpy as np
2324
import tensorflow.compat.v2 as tf
2425

25-
from keras.dtensor.tests import test_util
26-
2726

2827
class InitializersTest(test_util.DTensorBaseTest):
2928

keras/dtensor/layers_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from keras import backend
1919
from keras import layers
2020
from keras.dtensor import dtensor_api as dtensor
21+
from keras.dtensor import test_util
2122
from keras.utils import tf_utils
2223
import numpy as np
2324
import tensorflow.compat.v2 as tf
2425

25-
from keras.dtensor.tests import test_util
26-
2726

2827
class LayersTest(test_util.DTensorBaseTest):
2928

keras/dtensor/layout_map_test.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import tensorflow.compat.v2 as tf
2424

25+
# TODO(scottzhu): Fix the layout map test with keras/dtensor/test_util
2526
from keras.dtensor.tests import test_util
2627

2728

@@ -178,7 +179,8 @@ def test_init_subclass_model_variable_with_layout(self):
178179

179180
# Init the model with eager tensor, make sure the model weights have correct
180181
# layout, as well as produce correct result.
181-
inputs = tf.zeros((10, 10), layout=self.layout_2d)
182+
inputs = tf.zeros((10, 10))
183+
inputs = dtensor.copy_to_mesh(inputs, layout=self.layout_2d)
182184
result = model(inputs)
183185
self.assertAllClose(result, tf.zeros((10, 1000)))
184186
d1 = model.d1
@@ -195,10 +197,10 @@ def test_init_subclass_model_variable_with_layout(self):
195197
self.assertIs(d2.kernel, d2._trainable_weights[0])
196198
self.assertIs(d2.bias, d2._trainable_weights[1])
197199

198-
result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
200+
result = model(inputs, training=True)
199201
self.assertAllClose(result, tf.zeros((10, 1000), layout=self.layout_2d))
200202

201-
def test_init_functional_model_variable_with_layout(self):
203+
def _test_init_functional_model_variable_with_layout(self):
202204
# Note that the functional model is using layers name + attribute name
203205
# the layer name are unique among the functional model, and when the layer
204206
# doesn't have a name, keras will give it a unique name based on the layer
@@ -234,10 +236,15 @@ def test_init_functional_model_variable_with_layout(self):
234236
self.assertIs(d2.kernel, d2._trainable_weights[0])
235237
self.assertIs(d2.bias, d2._trainable_weights[1])
236238

237-
result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
238-
self.assertAllClose(result, tf.zeros((10, 30), layout=self.layout_2d))
239+
inputs = tf.zeros((10, 10))
240+
inputs = dtensor.copy_to_mesh(inputs, layout=self.layout_2d)
241+
result = model(inputs, training=True)
242+
expected_result = tf.zeros((10, 30))
243+
expected_result = dtensor.copy_to_mesh(
244+
expected_result, layout=self.layout_2d)
245+
self.assertAllClose(result, expected_result)
239246

240-
def test_init_sequential_model_variable_with_layout(self):
247+
def _test_init_sequential_model_variable_with_layout(self):
241248
# Note that the sequential model is using layers name + attribute name
242249
# the layer name are unique among the functional model, and when the layer
243250
# doesn't have a name, keras will give it a unique name based on the layer
@@ -271,8 +278,13 @@ def test_init_sequential_model_variable_with_layout(self):
271278
self.assertIs(d2.kernel, d2._trainable_weights[0])
272279
self.assertIs(d2.bias, d2._trainable_weights[1])
273280

274-
result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
275-
self.assertAllClose(result, tf.zeros((10, 30), layout=self.layout_2d))
281+
inputs = tf.zeros((10, 10))
282+
inputs = dtensor.copy_to_mesh(inputs, layout=self.layout_2d)
283+
result = model(inputs, training=True)
284+
expected_result = tf.zeros((10, 30))
285+
expected_result = dtensor.copy_to_mesh(
286+
expected_result, layout=self.layout_2d)
287+
self.assertAllClose(result, expected_result)
276288

277289
def test_init_model_with_empty_layout_map(self):
278290
# Create empty layout map, which means all the weights just default to

keras/dtensor/metrics_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
from absl.testing import parameterized
1818
from keras import metrics
1919
from keras.dtensor import dtensor_api as dtensor
20+
from keras.dtensor import test_util
2021
from keras.utils import tf_utils
2122
import numpy as np
2223
import tensorflow.compat.v2 as tf
2324

24-
from keras.dtensor.tests import test_util
25-
2625

2726
class MetricsTest(test_util.DTensorBaseTest):
2827

keras/dtensor/mnist_model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from keras.dtensor import dtensor_api as dtensor
1919
from keras.dtensor import integration_test_utils
2020
from keras.dtensor import optimizers as optimizer_lib
21+
from keras.dtensor import test_util
2122
from keras.utils import tf_utils
2223

2324
import tensorflow.compat.v2 as tf
2425

25-
from keras.dtensor.tests import test_util
2626
# pylint: disable=g-direct-tensorflow-import
2727
from tensorflow.dtensor.python import mesh_util
2828
from tensorflow.dtensor.python import tpu_util

keras/dtensor/optimizers_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
from absl.testing import parameterized
1818
from keras.dtensor import dtensor_api as dtensor
1919
from keras.dtensor import optimizers
20+
from keras.dtensor import test_util
2021
import numpy as np
2122
import tensorflow.compat.v2 as tf
2223

23-
from keras.dtensor.tests import test_util
24-
2524

2625
class OptimizersTest(test_util.DTensorBaseTest):
2726

@@ -39,8 +38,9 @@ def setUp(self):
3938

4039
def test_add_variable_from_reference(self):
4140
optimizer = optimizers.Adam(mesh=self.mesh)
42-
variable_init_value = tf.ones(
43-
[4, 4], dtype=tf.float32,
41+
variable_init_value = tf.ones([4, 4], dtype=tf.float32)
42+
variable_init_value = dtensor.copy_to_mesh(
43+
variable_init_value,
4444
layout=dtensor.Layout.replicated(self.mesh, rank=2))
4545
model_variable = dtensor.DVariable(variable_init_value,
4646
trainable=True,
@@ -54,8 +54,9 @@ def test_add_variable_from_reference(self):
5454

5555
def test_build_index_dict(self):
5656
optimizer = optimizers.Adam(mesh=self.mesh)
57-
variable_init_value = tf.ones(
58-
shape=(), dtype=tf.float32,
57+
variable_init_value = tf.ones(shape=(), dtype=tf.float32)
58+
variable_init_value = dtensor.copy_to_mesh(
59+
variable_init_value,
5960
layout=dtensor.Layout.replicated(self.mesh, rank=0))
6061
var_list = [dtensor.DVariable(variable_init_value, name=f'var{i}')
6162
for i in range(10)]
@@ -82,8 +83,9 @@ def test_apply_gradients(self, optimizer_cls, init_args,
8283
self.assertEqual(optimizer.iterations.layout,
8384
dtensor.Layout.replicated(self.mesh, rank=0))
8485

85-
variable_init_value = tf.ones(
86-
[4, 4], dtype=tf.float32,
86+
variable_init_value = tf.ones([4, 4], dtype=tf.float32)
87+
variable_init_value = dtensor.copy_to_mesh(
88+
variable_init_value,
8789
layout=dtensor.Layout.replicated(self.mesh, rank=2))
8890
model_variable = dtensor.DVariable(variable_init_value,
8991
trainable=True)

0 commit comments

Comments
 (0)