14
14
X = np .random .uniform (size = (n_eg , FEATURES )).astype ("float32" )
15
15
16
16
17
- def shallow_net (single_output = False , loss = None , compile = False ):
17
+ def shallow_net (outputs = None , loss = None , compile = False ):
18
18
model = tf .keras .Sequential ()
19
19
model .add (tf .keras .layers .Input (shape = (FEATURES ,)))
20
- if single_output :
21
- model .add (tf .keras .layers .Dense (1 ))
22
- else :
20
+ if outputs is None :
23
21
model .add (tf .keras .layers .Dense (N_CLASSES ))
22
+ else :
23
+ model .add (tf .keras .layers .Dense (outputs ))
24
24
25
25
if compile :
26
26
model .compile (loss = loss )
@@ -45,7 +45,7 @@ def test_user_compiled(loss):
45
45
"""Test to make sure that user compiled classification models work with all
46
46
classification losses.
47
47
"""
48
- model__single_output = True if "binary" in loss else False
48
+ model__outputs = 1 if "binary" in loss else None
49
49
if loss == "binary_crosentropy" :
50
50
y = np .random .randint (0 , 2 , size = (n_eg ,))
51
51
elif loss == "categorical_crossentropy" :
@@ -59,7 +59,7 @@ def test_user_compiled(loss):
59
59
shallow_net ,
60
60
model__compile = True ,
61
61
model__loss = loss ,
62
- model__single_output = model__single_output ,
62
+ model__outputs = model__outputs ,
63
63
)
64
64
est .partial_fit (X , y )
65
65
@@ -69,7 +69,7 @@ def test_user_compiled(loss):
69
69
70
70
class NoEncoderClf (KerasClassifier ):
71
71
"""A classifier overriding default target encoding.
72
- This simulates a user implementing custom encoding logic in
72
+ This simulates a user implementing custom encoding logic in
73
73
target_encoder to support multiclass-multioutput or
74
74
multilabel-indicator, which by default would raise an error.
75
75
"""
@@ -79,40 +79,58 @@ def target_encoder(self):
79
79
return FunctionTransformer ()
80
80
81
81
82
- @pytest .mark .parametrize ("use_case" , ["multilabel-indicator" , "multiclass-multioutput" ])
83
- def test_classifier_unsupported_multi_output_tasks (use_case ):
82
+ @pytest .mark .parametrize (
83
+ "use_case,wrapper_cls" ,
84
+ [
85
+ ("multilabel-indicator" , NoEncoderClf ),
86
+ ("multiclass-multioutput" , NoEncoderClf ),
87
+ ("classification_w_onehot_targets" , KerasClassifier ),
88
+ ],
89
+ )
90
+ def test_classifier_unsupported_multi_output_tasks (use_case , wrapper_cls ):
84
91
"""Test for an appropriate error for tasks that are not supported
85
92
by `loss="auto"`.
86
93
"""
94
+ extra = ""
95
+ fix_loss = None
87
96
if use_case == "multiclass-multioutput" :
88
97
y1 = np .random .randint (0 , 1 , size = len (X ))
89
98
y2 = np .random .randint (0 , 2 , size = len (X ))
90
99
y = np .column_stack ([y1 , y2 ])
91
100
elif use_case == "multilabel-indicator" :
92
101
y1 = np .random .randint (0 , 1 , size = len (X ))
93
102
y = np .column_stack ([y1 , y1 ])
94
- est = NoEncoderClf (shallow_net , model__compile = False )
95
- with pytest .raises (
96
- NotImplementedError , match = '`loss="auto"` is not supported for tasks of type'
97
- ):
98
- est .initialize (X , y )
103
+ y [0 , :] = 1
104
+ fix_loss = "binary_crossentropy"
105
+ extra = f'loss="{ fix_loss } " might be appropriate'
106
+ elif use_case == "classification_w_onehot_targets" :
107
+ y = np .random .choice (N_CLASSES , size = len (X )).astype (int )
108
+ y = OneHotEncoder (sparse = False ).fit_transform (y .reshape (- 1 , 1 ))
109
+ fix_loss = "categorical_crossentropy"
110
+ extra = f'loss="{ fix_loss } " might be appropriate'
111
+ match = '`loss="auto"` is not supported for tasks of type'
112
+ if extra :
113
+ match += f"(.|\n )+{ extra } "
114
+ with pytest .raises (NotImplementedError , match = match ):
115
+ wrapper_cls (shallow_net , model__compile = False ).initialize (X , y )
116
+ if fix_loss :
117
+ wrapper_cls (shallow_net , model__compile = False , loss = fix_loss ).initialize (X , y )
99
118
100
119
101
120
@pytest .mark .parametrize (
102
- "use_case,supported " ,
121
+ "use_case" ,
103
122
[
104
- ("binary_classification" , True ),
105
- ("binary_classification_w_one_class" , True ),
106
- ("classification_w_1d_targets" , True ),
107
- ("classification_w_onehot_targets" , False ),
123
+ "binary_classification" ,
124
+ "binary_classification_w_one_class" ,
125
+ "classification_w_1d_targets" ,
108
126
],
109
127
)
110
- def test_classifier_default_loss_only_model_specified (use_case , supported ):
128
+ def test_classifier_default_loss_only_model_specified (use_case ):
111
129
"""Test that KerasClassifier will auto-determine a loss function
112
130
when only the model is specified.
113
131
"""
114
132
115
- model__single_output = True if "binary" in use_case else False
133
+ model__outputs = 1 if "binary" in use_case else None
116
134
if use_case == "binary_classification" :
117
135
exp_loss = "binary_crossentropy"
118
136
y = np .random .choice (2 , size = len (X )).astype (int )
@@ -122,21 +140,11 @@ def test_classifier_default_loss_only_model_specified(use_case, supported):
122
140
elif use_case == "classification_w_1d_targets" :
123
141
exp_loss = "sparse_categorical_crossentropy"
124
142
y = np .random .choice (N_CLASSES , size = (len (X ), 1 )).astype (int )
125
- elif use_case == "classification_w_onehot_targets" :
126
- y = np .random .choice (N_CLASSES , size = len (X )).astype (int )
127
- y = OneHotEncoder (sparse = False ).fit_transform (y .reshape (- 1 , 1 ))
128
143
129
- est = KerasClassifier (model = shallow_net , model__single_output = model__single_output )
144
+ est = KerasClassifier (model = shallow_net , model__outputs = model__outputs )
130
145
131
- if supported :
132
- est .fit (X , y = y )
133
- assert loss_name (est .model_ .loss ) == exp_loss
134
- else :
135
- with pytest .raises (
136
- NotImplementedError ,
137
- match = '`loss="auto"` is not supported for tasks of type' ,
138
- ):
139
- est .fit (X , y = y )
146
+ est .fit (X , y = y )
147
+ assert loss_name (est .model_ .loss ) == exp_loss
140
148
assert est .loss == "auto"
141
149
142
150
@@ -148,7 +156,9 @@ def test_regressor_default_loss_only_model_specified(use_case):
148
156
y = np .random .uniform (size = len (X ))
149
157
if use_case == "multi_output" :
150
158
y = np .column_stack ([y , y ])
151
- est = KerasRegressor (model = shallow_net , model__single_output = True )
159
+ est = KerasRegressor (
160
+ model = shallow_net , model__outputs = 1 if "single" in use_case else 2
161
+ )
152
162
est .fit (X , y )
153
163
assert est .loss == "auto"
154
164
assert loss_name (est .model_ .loss ) == "mean_squared_error"
@@ -202,7 +212,7 @@ def test_multi_output_support(user_compiled, est_cls):
202
212
def test_multiclass_single_output_unit ():
203
213
"""Test that multiclass targets requires > 1 output units.
204
214
"""
205
- est = KerasClassifier (model = shallow_net , model__single_output = True )
215
+ est = KerasClassifier (model = shallow_net , model__outputs = 1 )
206
216
y = np .random .choice (N_CLASSES , size = (len (X ), 1 )).astype (int )
207
217
with pytest .raises (
208
218
ValueError ,
@@ -214,7 +224,7 @@ def test_multiclass_single_output_unit():
214
224
def test_binary_multiple_output_units ():
215
225
"""Test that binary targets requires exactly 1 output unit.
216
226
"""
217
- est = KerasClassifier (model = shallow_net , model__single_output = False )
227
+ est = KerasClassifier (model = shallow_net , model__outputs = 2 )
218
228
y = np .random .choice (2 , size = len (X )).astype (int )
219
229
with pytest .raises (
220
230
ValueError ,
0 commit comments