1717import os
1818
1919import tensorflow .compat .v2 as tf
20+ from absl import flags
2021from absl .testing import parameterized
2122
2223from tf_keras import backend
4041from tf_keras .applications import xception
4142from tf_keras .testing_infra import test_utils
4243
44+ _IMAGE_DATA_FORMAT = flags .DEFINE_string (
45+ "image_data_format" ,
46+ "channels_first" ,
47+ "The image data format to use for the test." ,
48+ )
49+
4350MODEL_LIST_NO_NASNET = [
4451 (resnet .ResNet50 , 2048 ),
4552 (resnet .ResNet101 , 2048 ),
120127MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST
121128
122129MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt" , "NASNet" , "RegNetX" , "RegNetY" ]
123- # Add each data format for each model
124- test_parameters_with_image_data_format = [
125- (
126- "{}_{}" .format (model [0 ].__name__ , image_data_format ),
127- * model ,
128- image_data_format ,
129- )
130- for image_data_format in ["channels_first" , "channels_last" ]
131- for model in MODEL_LIST
132- ]
133130
134131# Parameters for loading weights for MobileNetV3.
135132# (class, alpha, minimalistic, include_top)
@@ -183,8 +180,9 @@ def skip_if_invalid_image_data_format_for_model(
183180 "{} does not support channels first" .format (app .__name__ )
184181 )
185182
186- @parameterized .named_parameters (test_parameters_with_image_data_format )
187- def test_application_base (self , app , _ , image_data_format ):
183+ @parameterized .parameters (* MODEL_LIST )
184+ def test_application_base (self , app , _ ):
185+ image_data_format = _IMAGE_DATA_FORMAT .value
188186 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
189187 backend .set_image_data_format (image_data_format )
190188 # Can be instantiated with default arguments
@@ -200,8 +198,9 @@ def test_application_base(self, app, _, image_data_format):
200198 self .assertEqual (len (model .weights ), len (reconstructed_model .weights ))
201199 backend .clear_session ()
202200
203- @parameterized .named_parameters (test_parameters_with_image_data_format )
204- def test_application_notop (self , app , last_dim , image_data_format ):
201+ @parameterized .parameters (* MODEL_LIST )
202+ def test_application_notop (self , app , last_dim ):
203+ image_data_format = _IMAGE_DATA_FORMAT .value
205204 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
206205 backend .set_image_data_format (image_data_format )
207206 if image_data_format == "channels_first" :
@@ -226,10 +225,9 @@ def test_application_notop(self, app, last_dim, image_data_format):
226225 self .assertShapeEqual (output_shape , correct_output_shape )
227226 backend .clear_session ()
228227
229- @parameterized .named_parameters (test_parameters_with_image_data_format )
230- def test_application_notop_custom_input_shape (
231- self , app , last_dim , image_data_format
232- ):
228+ @parameterized .parameters (* MODEL_LIST )
229+ def test_application_notop_custom_input_shape (self , app , last_dim ):
230+ image_data_format = _IMAGE_DATA_FORMAT .value
233231 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
234232 backend .set_image_data_format (image_data_format )
235233 if image_data_format == "channels_first" :
@@ -261,10 +259,9 @@ def test_application_classifier_activation(self, app, _):
261259 last_layer_act = model .layers [- 1 ].activation .__name__
262260 self .assertEqual (last_layer_act , "softmax" )
263261
264- @parameterized .named_parameters (test_parameters_with_image_data_format )
265- def test_application_variable_input_channels (
266- self , app , last_dim , image_data_format
267- ):
262+ @parameterized .parameters (* MODEL_LIST )
263+ def test_application_variable_input_channels (self , app , last_dim ):
264+ image_data_format = _IMAGE_DATA_FORMAT .value
268265 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
269266 backend .set_image_data_format (image_data_format )
270267 if backend .image_data_format () == "channels_first" :
@@ -303,9 +300,10 @@ def test_mobilenet_v3_load_weights(
303300 include_top = include_top ,
304301 )
305302
306- @parameterized .named_parameters ( test_parameters_with_image_data_format )
303+ @parameterized .parameters ( * MODEL_LIST )
307304 @test_utils .run_v2_only
308- def test_model_checkpoint (self , app , _ , image_data_format ):
305+ def test_model_checkpoint (self , app , _ ):
306+ image_data_format = _IMAGE_DATA_FORMAT .value
309307 self .skip_if_invalid_image_data_format_for_model (app , image_data_format )
310308 backend .set_image_data_format (image_data_format )
311309
0 commit comments