Skip to content

Commit ca868f5

Browse files
committed
add test for number of output units
1 parent c7b567f commit ca868f5

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

scikeras/wrappers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1329,11 +1329,15 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
13291329
'Only single-output models are supported with `loss="auto"`'
13301330
)
13311331
if self.target_type_ == "binary":
1332+
if self.model_.outputs[0].shape[1] != 1:
1333+
raise ValueError(
1334+
"Binary classification expects a model with exactly 1 output unit."
1335+
)
13321336
compile_kwargs["loss"] = "binary_crossentropy"
13331337
elif self.target_type_ == "multiclass":
13341338
if self.model_.outputs[0].shape[1] == 1:
13351339
raise ValueError(
1336-
f"Multi-class targets require the model to have >1 output unit instead of {self.model_.outputs[0].shape} units"
1340+
"Multi-class targets require the model to have >1 output units."
13371341
)
13381342
compile_kwargs["loss"] = "sparse_categorical_crossentropy"
13391343
else:

tests/test_loss_auto.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,27 @@ def test_multi_output_support(user_compiled, est_cls):
197197
match='Only single-output models are supported with `loss="auto"`',
198198
):
199199
est.fit(X, y)
200+
201+
202+
def test_multiclass_single_output_unit():
203+
"""Test that multiclass targets requires > 1 output units.
204+
"""
205+
est = KerasClassifier(model=shallow_net, model__single_output=True)
206+
y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int)
207+
with pytest.raises(
208+
ValueError,
209+
match="Multi-class targets require the model to have >1 output units",
210+
):
211+
est.fit(X, y)
212+
213+
214+
def test_binary_multiple_output_units():
215+
"""Test that binary targets requires exactly 1 output unit.
216+
"""
217+
est = KerasClassifier(model=shallow_net, model__single_output=False)
218+
y = np.random.choice(2, size=len(X)).astype(int)
219+
with pytest.raises(
220+
ValueError,
221+
match="Binary classification expects a model with exactly 1 output unit",
222+
):
223+
est.fit(X, y)

0 commit comments

Comments
 (0)