Skip to content

Commit 40d179c

Browse files
BrianWiedertensorflower-gardener
authored andcommitted
Split applications_test into two tests, one for channel_first and another for channel_last to reduce test time, and flakiness.
PiperOrigin-RevId: 598908517
1 parent 6f9283e commit 40d179c

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

tf_keras/applications/BUILD

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,30 @@ py_library(
5252
)
5353

5454
tf_py_test(
55-
name = "applications_test",
56-
size = "medium",
55+
name = "applications_test_channels_first",
56+
srcs = ["applications_test.py"],
57+
args = ["--image_data_format=channels_first"],
58+
main = "applications_test.py",
59+
shard_count = 50,
60+
tags = [
61+
"no_oss", # b/318174391
62+
"no_rocm",
63+
"notsan", # b/168814536
64+
"requires-net:external",
65+
],
66+
deps = [
67+
":applications",
68+
"//:expect_absl_installed", # absl/testing:parameterized
69+
"//:expect_tensorflow_installed",
70+
"//tf_keras/testing_infra:test_combinations",
71+
],
72+
)
73+
74+
tf_py_test(
75+
name = "applications_test_channels_last",
5776
srcs = ["applications_test.py"],
77+
args = ["--image_data_format=channels_last"],
78+
main = "applications_test.py",
5879
shard_count = 50,
5980
tags = [
6081
"no_oss", # b/318174391

tf_keras/applications/applications_test.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818

1919
import tensorflow.compat.v2 as tf
20+
from absl import flags
2021
from absl.testing import parameterized
2122

2223
from tf_keras import backend
@@ -40,6 +41,12 @@
4041
from tf_keras.applications import xception
4142
from tf_keras.testing_infra import test_utils
4243

44+
_IMAGE_DATA_FORMAT = flags.DEFINE_string(
45+
"image_data_format",
46+
"channels_first",
47+
"The image data format to use for the test.",
48+
)
49+
4350
MODEL_LIST_NO_NASNET = [
4451
(resnet.ResNet50, 2048),
4552
(resnet.ResNet101, 2048),
@@ -120,16 +127,6 @@
120127
MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
121128

122129
MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "NASNet", "RegNetX", "RegNetY"]
123-
# Add each data format for each model
124-
test_parameters_with_image_data_format = [
125-
(
126-
"{}_{}".format(model[0].__name__, image_data_format),
127-
*model,
128-
image_data_format,
129-
)
130-
for image_data_format in ["channels_first", "channels_last"]
131-
for model in MODEL_LIST
132-
]
133130

134131
# Parameters for loading weights for MobileNetV3.
135132
# (class, alpha, minimalistic, include_top)
@@ -183,8 +180,9 @@ def skip_if_invalid_image_data_format_for_model(
183180
"{} does not support channels first".format(app.__name__)
184181
)
185182

186-
@parameterized.named_parameters(test_parameters_with_image_data_format)
187-
def test_application_base(self, app, _, image_data_format):
183+
@parameterized.parameters(*MODEL_LIST)
184+
def test_application_base(self, app, _):
185+
image_data_format = _IMAGE_DATA_FORMAT.value
188186
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
189187
backend.set_image_data_format(image_data_format)
190188
# Can be instantiated with default arguments
@@ -200,8 +198,9 @@ def test_application_base(self, app, _, image_data_format):
200198
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
201199
backend.clear_session()
202200

203-
@parameterized.named_parameters(test_parameters_with_image_data_format)
204-
def test_application_notop(self, app, last_dim, image_data_format):
201+
@parameterized.parameters(*MODEL_LIST)
202+
def test_application_notop(self, app, last_dim):
203+
image_data_format = _IMAGE_DATA_FORMAT.value
205204
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
206205
backend.set_image_data_format(image_data_format)
207206
if image_data_format == "channels_first":
@@ -226,10 +225,9 @@ def test_application_notop(self, app, last_dim, image_data_format):
226225
self.assertShapeEqual(output_shape, correct_output_shape)
227226
backend.clear_session()
228227

229-
@parameterized.named_parameters(test_parameters_with_image_data_format)
230-
def test_application_notop_custom_input_shape(
231-
self, app, last_dim, image_data_format
232-
):
228+
@parameterized.parameters(*MODEL_LIST)
229+
def test_application_notop_custom_input_shape(self, app, last_dim):
230+
image_data_format = _IMAGE_DATA_FORMAT.value
233231
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
234232
backend.set_image_data_format(image_data_format)
235233
if image_data_format == "channels_first":
@@ -261,10 +259,9 @@ def test_application_classifier_activation(self, app, _):
261259
last_layer_act = model.layers[-1].activation.__name__
262260
self.assertEqual(last_layer_act, "softmax")
263261

264-
@parameterized.named_parameters(test_parameters_with_image_data_format)
265-
def test_application_variable_input_channels(
266-
self, app, last_dim, image_data_format
267-
):
262+
@parameterized.parameters(*MODEL_LIST)
263+
def test_application_variable_input_channels(self, app, last_dim):
264+
image_data_format = _IMAGE_DATA_FORMAT.value
268265
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
269266
backend.set_image_data_format(image_data_format)
270267
if backend.image_data_format() == "channels_first":
@@ -303,9 +300,10 @@ def test_mobilenet_v3_load_weights(
303300
include_top=include_top,
304301
)
305302

306-
@parameterized.named_parameters(test_parameters_with_image_data_format)
303+
@parameterized.parameters(*MODEL_LIST)
307304
@test_utils.run_v2_only
308-
def test_model_checkpoint(self, app, _, image_data_format):
305+
def test_model_checkpoint(self, app, _):
306+
image_data_format = _IMAGE_DATA_FORMAT.value
309307
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
310308
backend.set_image_data_format(image_data_format)
311309

0 commit comments

Comments
 (0)