Skip to content

Commit 8d50bf9

Browse files
authored
MAINT: better error message about one-hot encoded targets w/ loss="auto" (#218)
1 parent ca868f5 commit 8d50bf9

File tree

5 files changed

+124
-52
lines changed

5 files changed

+124
-52
lines changed

docs/source/advanced.rst

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
===================================
2-
Advanced Usage of SciKeras Wrappers
3-
===================================
1+
==============
2+
Advanced Usage
3+
==============
44

55
Wrapper Classes
66
---------------
@@ -128,6 +128,43 @@ offer an easy way to compile and tune compilation parameters. Examples:
128128
In all cases, returning an un-compiled model is equivalent to
129129
calling ``model.compile(**compile_kwargs)`` within ``model_build_fn``.
130130

131+
.. _loss-selection:
132+
133+
Loss selection
134+
++++++++++++++
135+
136+
If you do not explicitly define a loss, SciKeras attempts to find a loss
137+
that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`).
138+
139+
For guidance selecting losses in Keras, please see Jason Brownlee's
140+
excellent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_
141+
as well as `Keras Losses docs`_.
142+
143+
Default losses are selected as follows:
144+
145+
Classification
146+
..............
147+
148+
+-----------+-----------+----------+---------------------------------+
149+
| # outputs | # classes | encoding | loss |
150+
+===========+===========+==========+=================================+
151+
| 1 | <= 2 | any | binary crossentropy |
152+
+-----------+-----------+----------+---------------------------------+
153+
| 1 | >=2 | labels | sparse categorical crossentropy |
154+
+-----------+-----------+----------+---------------------------------+
155+
| 1 | >=2 | one-hot | unsupported |
156+
+-----------+-----------+----------+---------------------------------+
157+
| > 1 | -- | -- | unsupported |
158+
+-----------+-----------+----------+---------------------------------+
159+
160+
Note that SciKeras will not automatically infer the loss for one-hot encoded targets,
161+
you would need to explicitly specify `loss="categorical_crossentropy"`.
162+
163+
Regression
164+
..........
165+
166+
Regression always defaults to mean squared error.
167+
For multi-output models, Keras will use the sum of each output's loss.
131168

132169
Arguments to ``model_build_fn``
133170
-------------------------------
@@ -287,3 +324,7 @@ and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these sc
287324
.. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks
288325

289326
.. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics
327+
328+
.. _Keras Losses docs: https://www.tensorflow.org/api_docs/python/tf/keras/losses
329+
330+
.. _How to Choose Loss Functions When Training Deep Learning Neural Networks: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/

docs/source/quickstart.rst

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,25 @@ it on a toy classification dataset using SciKeras
3838
model.add(keras.layers.Activation("softmax"))
3939
return model
4040
41-
clf = KerasClassifier(
42-
get_model,
43-
loss="sparse_categorical_crossentropy",
44-
hidden_layer_dim=100,
45-
)
41+
clf = KerasClassifier(get_model, hidden_layer_dim=100)
4642
4743
clf.fit(X, y)
4844
y_proba = clf.predict_proba(X)
4945
5046
47+
Note that SciKeras even chooses a loss function and compiles your model.
48+
To override the default loss, simply specify a loss function:
49+
50+
.. code-block:: diff
51+
52+
-KerasClassifier(get_model, hidden_layer_dim=100)
53+
+KerasClassifier(get_model, loss="categorical_crossentropy")
54+
55+
In this case, you would need to specify the loss since SciKeras
56+
will not default to categorical crossentropy, even for one-hot
57+
encoded targets.
58+
See :ref:`loss-selection` for more details.
59+
5160
In an sklearn Pipeline
5261
----------------------
5362

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ version = "0.2.1"
3030

3131
[tool.poetry.dependencies]
3232
importlib-metadata = {version = "^3.4.0", python = "<3.8"}
33-
python = ">=3.6.7, <3.9"
33+
python = "^3.11.0"
3434
scikit-learn = "^0.22.0"
3535
tensorflow = "^2.4.0"
3636

scikeras/wrappers.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from tensorflow.keras import optimizers as optimizers_module
2323
from tensorflow.keras.models import Model
2424
from tensorflow.keras.utils import register_keras_serializable
25-
from tensorflow.python.types.core import Value
2625

2726
from scikeras._utils import (
2827
TFRandomState,
@@ -1328,24 +1327,37 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
13281327
raise ValueError(
13291328
'Only single-output models are supported with `loss="auto"`'
13301329
)
1330+
loss = None
1331+
hint = ""
13311332
if self.target_type_ == "binary":
13321333
if self.model_.outputs[0].shape[1] != 1:
13331334
raise ValueError(
13341335
"Binary classification expects a model with exactly 1 output unit."
13351336
)
1336-
compile_kwargs["loss"] = "binary_crossentropy"
1337+
loss = "binary_crossentropy"
13371338
elif self.target_type_ == "multiclass":
13381339
if self.model_.outputs[0].shape[1] == 1:
13391340
raise ValueError(
13401341
"Multi-class targets require the model to have >1 output units."
13411342
)
1342-
compile_kwargs["loss"] = "sparse_categorical_crossentropy"
1343-
else:
1344-
raise NotImplementedError(
1343+
loss = "sparse_categorical_crossentropy"
1344+
elif self.target_type_ == "multilabel-indicator":
1345+
# one-hot encoded multiclass problem OR multilabel-indicator problem
1346+
hint = (
1347+
"For this type of problem, the following may help:"
1348+
'\n - If there is only one class per example, loss="categorical_crossentropy" might be appropriate.'
1349+
'\n - If there are multiple classes per example, loss="binary_crossentropy" might be appropriate.'
1350+
)
1351+
if loss is None:
1352+
msg = (
13451353
f'`loss="auto"` is not supported for tasks of type {self.target_type_}.'
1346-
" Instead, you must explicitly pass a loss function, for example:"
1354+
"\nInstead, you must compile the model yourself or explicitly pass a loss function, for example:"
13471355
'\n clf = KerasClassifier(..., loss="categorical_crossentropy")'
13481356
)
1357+
if hint:
1358+
msg += f"\n\n{hint}"
1359+
raise NotImplementedError(msg)
1360+
compile_kwargs["loss"] = loss
13491361
self.model_.compile(**compile_kwargs)
13501362

13511363
@staticmethod

tests/test_loss_auto.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
X = np.random.uniform(size=(n_eg, FEATURES)).astype("float32")
1515

1616

17-
def shallow_net(single_output=False, loss=None, compile=False):
17+
def shallow_net(outputs=None, loss=None, compile=False):
1818
model = tf.keras.Sequential()
1919
model.add(tf.keras.layers.Input(shape=(FEATURES,)))
20-
if single_output:
21-
model.add(tf.keras.layers.Dense(1))
22-
else:
20+
if outputs is None:
2321
model.add(tf.keras.layers.Dense(N_CLASSES))
22+
else:
23+
model.add(tf.keras.layers.Dense(outputs))
2424

2525
if compile:
2626
model.compile(loss=loss)
@@ -45,7 +45,7 @@ def test_user_compiled(loss):
4545
"""Test to make sure that user compiled classification models work with all
4646
classification losses.
4747
"""
48-
model__single_output = True if "binary" in loss else False
48+
model__outputs = 1 if "binary" in loss else None
4949
if loss == "binary_crosentropy":
5050
y = np.random.randint(0, 2, size=(n_eg,))
5151
elif loss == "categorical_crossentropy":
@@ -59,7 +59,7 @@ def test_user_compiled(loss):
5959
shallow_net,
6060
model__compile=True,
6161
model__loss=loss,
62-
model__single_output=model__single_output,
62+
model__outputs=model__outputs,
6363
)
6464
est.partial_fit(X, y)
6565

@@ -69,7 +69,7 @@ def test_user_compiled(loss):
6969

7070
class NoEncoderClf(KerasClassifier):
7171
"""A classifier overriding default target encoding.
72-
This simulates a user implementing custom encoding logic in
72+
This simulates a user implementing custom encoding logic in
7373
target_encoder to support multiclass-multioutput or
7474
multilabel-indicator, which by default would raise an error.
7575
"""
@@ -79,40 +79,58 @@ def target_encoder(self):
7979
return FunctionTransformer()
8080

8181

82-
@pytest.mark.parametrize("use_case", ["multilabel-indicator", "multiclass-multioutput"])
83-
def test_classifier_unsupported_multi_output_tasks(use_case):
82+
@pytest.mark.parametrize(
83+
"use_case,wrapper_cls",
84+
[
85+
("multilabel-indicator", NoEncoderClf),
86+
("multiclass-multioutput", NoEncoderClf),
87+
("classification_w_onehot_targets", KerasClassifier),
88+
],
89+
)
90+
def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls):
8491
"""Test for an appropriate error for tasks that are not supported
8592
by `loss="auto"`.
8693
"""
94+
extra = ""
95+
fix_loss = None
8796
if use_case == "multiclass-multioutput":
8897
y1 = np.random.randint(0, 1, size=len(X))
8998
y2 = np.random.randint(0, 2, size=len(X))
9099
y = np.column_stack([y1, y2])
91100
elif use_case == "multilabel-indicator":
92101
y1 = np.random.randint(0, 1, size=len(X))
93102
y = np.column_stack([y1, y1])
94-
est = NoEncoderClf(shallow_net, model__compile=False)
95-
with pytest.raises(
96-
NotImplementedError, match='`loss="auto"` is not supported for tasks of type'
97-
):
98-
est.initialize(X, y)
103+
y[0, :] = 1
104+
fix_loss = "binary_crossentropy"
105+
extra = f'loss="{fix_loss}" might be appropriate'
106+
elif use_case == "classification_w_onehot_targets":
107+
y = np.random.choice(N_CLASSES, size=len(X)).astype(int)
108+
y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))
109+
fix_loss = "categorical_crossentropy"
110+
extra = f'loss="{fix_loss}" might be appropriate'
111+
match = '`loss="auto"` is not supported for tasks of type'
112+
if extra:
113+
match += f"(.|\n)+{extra}"
114+
with pytest.raises(NotImplementedError, match=match):
115+
wrapper_cls(shallow_net, model__compile=False).initialize(X, y)
116+
if fix_loss:
117+
wrapper_cls(shallow_net, model__compile=False, loss=fix_loss).initialize(X, y)
99118

100119

101120
@pytest.mark.parametrize(
102-
"use_case,supported",
121+
"use_case",
103122
[
104-
("binary_classification", True),
105-
("binary_classification_w_one_class", True),
106-
("classification_w_1d_targets", True),
107-
("classification_w_onehot_targets", False),
123+
"binary_classification",
124+
"binary_classification_w_one_class",
125+
"classification_w_1d_targets",
108126
],
109127
)
110-
def test_classifier_default_loss_only_model_specified(use_case, supported):
128+
def test_classifier_default_loss_only_model_specified(use_case):
111129
"""Test that KerasClassifier will auto-determine a loss function
112130
when only the model is specified.
113131
"""
114132

115-
model__single_output = True if "binary" in use_case else False
133+
model__outputs = 1 if "binary" in use_case else None
116134
if use_case == "binary_classification":
117135
exp_loss = "binary_crossentropy"
118136
y = np.random.choice(2, size=len(X)).astype(int)
@@ -122,21 +140,11 @@ def test_classifier_default_loss_only_model_specified(use_case, supported):
122140
elif use_case == "classification_w_1d_targets":
123141
exp_loss = "sparse_categorical_crossentropy"
124142
y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int)
125-
elif use_case == "classification_w_onehot_targets":
126-
y = np.random.choice(N_CLASSES, size=len(X)).astype(int)
127-
y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))
128143

129-
est = KerasClassifier(model=shallow_net, model__single_output=model__single_output)
144+
est = KerasClassifier(model=shallow_net, model__outputs=model__outputs)
130145

131-
if supported:
132-
est.fit(X, y=y)
133-
assert loss_name(est.model_.loss) == exp_loss
134-
else:
135-
with pytest.raises(
136-
NotImplementedError,
137-
match='`loss="auto"` is not supported for tasks of type',
138-
):
139-
est.fit(X, y=y)
146+
est.fit(X, y=y)
147+
assert loss_name(est.model_.loss) == exp_loss
140148
assert est.loss == "auto"
141149

142150

@@ -148,7 +156,9 @@ def test_regressor_default_loss_only_model_specified(use_case):
148156
y = np.random.uniform(size=len(X))
149157
if use_case == "multi_output":
150158
y = np.column_stack([y, y])
151-
est = KerasRegressor(model=shallow_net, model__single_output=True)
159+
est = KerasRegressor(
160+
model=shallow_net, model__outputs=1 if "single" in use_case else 2
161+
)
152162
est.fit(X, y)
153163
assert est.loss == "auto"
154164
assert loss_name(est.model_.loss) == "mean_squared_error"
@@ -202,7 +212,7 @@ def test_multi_output_support(user_compiled, est_cls):
202212
def test_multiclass_single_output_unit():
203213
"""Test that multiclass targets requires > 1 output units.
204214
"""
205-
est = KerasClassifier(model=shallow_net, model__single_output=True)
215+
est = KerasClassifier(model=shallow_net, model__outputs=1)
206216
y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int)
207217
with pytest.raises(
208218
ValueError,
@@ -214,7 +224,7 @@ def test_multiclass_single_output_unit():
214224
def test_binary_multiple_output_units():
215225
"""Test that binary targets requires exactly 1 output unit.
216226
"""
217-
est = KerasClassifier(model=shallow_net, model__single_output=False)
227+
est = KerasClassifier(model=shallow_net, model__outputs=2)
218228
y = np.random.choice(2, size=len(X)).astype(int)
219229
with pytest.raises(
220230
ValueError,

0 commit comments

Comments
 (0)