Skip to content

Commit d83bffa

Browse files
authored
Soften deprecation warning for **kwargs and add more guidance (#198)
1 parent b26a22d commit d83bffa

File tree

4 files changed

+150
-113
lines changed

4 files changed

+150
-113
lines changed

docs/source/migration.rst

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,17 @@ pass your loss function to the constructor:
4949
Variable keyword arguments in fit and predict
5050
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
5151

52-
In a future release of SciKeras, variable keyword arguments (commonly referred to as
53-
``**kwargs``) will be removed from fit and predict. To future
54-
proof your code, you should instead declare these parameters in your constructor:
52+
Keras supports variable keyword arguments (commonly referred to as ``**kwargs``) for ``fit`` and ``predict``.
53+
Scikit-Learn on the other hand does not support these arguments, and using them is largely incompatible with the Scikit-Learn ecosystem.
54+
As a compromise, SciKeras supports these arguments, but we recommend that you set parameters using the constructor
55+
or ``set_params`` for first-class SciKeras support.
56+
57+
.. warning::
58+
59+
Passing keyword arguments to ``fit`` and ``predict`` is deprecated and will be removed in a future version of SciKeras.
60+
61+
62+
For example, to declare ``batch_size`` in the constructor:
5563

5664
.. code:: diff
5765
@@ -64,7 +72,19 @@ Or to declare separate values for ``fit`` and ``predict``:
6472

6573
.. code:: python
6674
67-
clf = KerasClassifier(fit__batch_size=32, predict__batch_size=10000)
75+
clf = KerasClassifier(..., fit__batch_size=32, predict__batch_size=10000)
76+
77+
If you want to change the parameters on a live instance, you can do:
78+
79+
.. code:: python
80+
81+
clf = KerasClassifier(...)
82+
clf.set_params(fit__batch_size=32, predict__batch_size=10000)
83+
clf.fit(...)
84+
85+
Functionally, this is the same as passing these parameters to ``fit``, just with one more function call.
86+
This is much more compatible with the Scikit-Learn API.
87+
In fact, this is what Scikit-Learn does in the background for hyperparameter tuning.
6888

6989
Renaming of ``build_fn`` to ``model``
7090
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -98,4 +118,4 @@ tunable parameters (i.e. settable via ``set_params``):
98118
+ clf = KerasClassifier(get_model, model__my_param=123) # option 2
99119
100120
That said, if you do not need them to work with ``set_params`` (which is only really
101-
necessary if you are doing hyperparameter tuning), you do not need to make any changes.
121+
necessary if you are doing hyperparameter tuning), you do not need to make any changes.

scikeras/wrappers.py

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Wrapper for using the Scikit-Learn API with Keras models.
22
"""
33
import inspect
4-
import os
54
import warnings
65

76
from collections import defaultdict
@@ -36,12 +35,21 @@
3635
from scikeras.utils.transformers import ClassifierLabelEncoder, RegressorTargetEncoder
3736

3837

38+
_kwarg_warn = """Passing estimator parameters as keyword arguments (aka as `**kwargs`) to `{0}` is not supported by the Scikit-Learn API, and will be removed in a future version of SciKeras.
39+
40+
To resolve this issue, either set these parameters in the constructor (e.g., `est = BaseWrapper(..., foo=bar)`) or via `set_params` (e.g., `est.set_params(foo=bar)`). The following parameters were passed to `{0}`:
41+
42+
{1}
43+
44+
More detail is available at https://www.adriangb.com/scikeras/migration.html#variable-keyword-arguments-in-fit-and-predict
45+
"""
46+
47+
3948
class BaseWrapper(BaseEstimator):
4049
"""Implementation of the scikit-learn classifier API for Keras.
4150
4251
Below are a list of SciKeras specific parameters. For details on other parameters,
43-
please see the see the
44-
[tf.keras.Model documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model).
52+
please see the `tf.keras.Model documentation <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`_.
4553
4654
Parameters
4755
----------
@@ -696,20 +704,23 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper":
696704
If not provided, then each sample is given unit weight.
697705
**kwargs : Dict[str, Any]
698706
Extra arguments to route to ``Model.fit``.
699-
This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
700-
These parameters can also be specified by prefixing `fit__` to a parameter at initialization;
701-
e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
707+
708+
Warnings
709+
--------
710+
Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``fit`` is not supported by the Scikit-Learn API,
711+
and will be removed in a future version of SciKeras.
712+
These parameters can also be specified by prefixing ``fit__`` to a parameter at initialization
713+
(``BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)``)
714+
or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
715+
702716
Returns
703717
-------
704718
BaseWrapper
705-
A reference to the instance that can be chain called
706-
(ex: instance.fit(X,y).transform(X) )
719+
A reference to the instance that can be chain called (``est.fit(X,y).transform(X)``).
707720
"""
708-
for k, v in kwargs.items():
709-
warnings.warn(
710-
"``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
711-
f" Instead, set fit arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
712-
)
721+
if kwargs:
722+
kwarg_list = "\n * ".join([f"`{k}={v}`" for k, v in kwargs.items()])
723+
warnings.warn(_kwarg_warn.format("fit", kwarg_list))
713724

714725
# epochs via kwargs > fit__epochs > epochs
715726
kwargs["epochs"] = kwargs.get(
@@ -886,11 +897,9 @@ def _predict_raw(self, X, **kwargs):
886897
For classification, this corresponds to predict_proba.
887898
For regression, this corresponds to predict.
888899
"""
889-
for k, v in kwargs.items():
890-
warnings.warn(
891-
"``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
892-
f" Instead, set predict arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
893-
)
900+
if kwargs:
901+
kwarg_list = "\n * ".join([f"`{k}={v}`" for k, v in kwargs.items()])
902+
warnings.warn(_kwarg_warn.format("predict", kwarg_list), stacklevel=2)
894903

895904
# check if fitted
896905
if not self.initialized_:
@@ -925,9 +934,14 @@ def predict(self, X, **kwargs):
925934
and n_features is the number of features.
926935
**kwargs : Dict[str, Any]
927936
Extra arguments to route to ``Model.predict``.
928-
This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
929-
These parameters can also be specified by prefixing `predict__` to a parameter at initialization;
930-
e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
937+
938+
Warnings
939+
--------
940+
Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``predict`` is not supported by the Scikit-Learn API,
941+
and will be removed in a future version of SciKeras.
942+
These parameters can also be specified by prefixing ``predict__`` to a parameter at initialization
943+
(``BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)``)
944+
or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
931945
932946
Returns
933947
-------
@@ -1090,8 +1104,7 @@ class KerasClassifier(BaseWrapper):
10901104
"""Implementation of the scikit-learn classifier API for Keras.
10911105
10921106
Below are a list of SciKeras specific parameters. For details on other parameters,
1093-
please see the see the
1094-
[tf.keras.Model documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model).
1107+
please see the `tf.keras.Model documentation <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`_.
10951108
10961109
Parameters
10971110
----------
@@ -1351,15 +1364,19 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "KerasClassifier":
13511364
If not provided, then each sample is given unit weight.
13521365
**kwargs : Dict[str, Any]
13531366
Extra arguments to route to ``Model.fit``.
1354-
This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
1355-
These parameters can also be specified by prefixing `fit__` to a parameter at initialization;
1356-
e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
1367+
1368+
Warnings
1369+
--------
1370+
Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``fit`` is not supported by the Scikit-Learn API,
1371+
and will be removed in a future version of SciKeras.
1372+
These parameters can also be specified by prefixing ``fit__`` to a parameter at initialization
1373+
(``KerasClassifier(..., fit__batch_size=32, predict__batch_size=1000)``)
1374+
or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
13571375
13581376
Returns
13591377
-------
13601378
KerasClassifier
1361-
A reference to the instance that can be chain called
1362-
(ex: instance.fit(X,y).transform(X) )
1379+
A reference to the instance that can be chain called (``est.fit(X,y).transform(X)``).
13631380
"""
13641381
self.classes_ = None
13651382
if self.class_weight is not None:
@@ -1415,9 +1432,14 @@ def predict_proba(self, X, **kwargs):
14151432
and n_features is the number of features.
14161433
**kwargs : Dict[str, Any]
14171434
Extra arguments to route to ``Model.predict``.
1418-
This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
1419-
These parameters can also be specified by prefixing `predict__` to a parameter at initialization;
1420-
e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
1435+
1436+
Warnings
1437+
--------
1438+
Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``predict_proba`` is not supported by the Scikit-Learn API,
1439+
and will be removed in a future version of SciKeras.
1440+
These parameters can also be specified by prefixing ``predict__`` to a parameter at initialization
1441+
(``KerasClassifier(..., fit__batch_size=32, predict__batch_size=1000)``)
1442+
or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
14211443
14221444
Returns
14231445
-------
@@ -1441,8 +1463,7 @@ class KerasRegressor(BaseWrapper):
14411463
"""Implementation of the scikit-learn regressor API for Keras.
14421464
14431465
Below are a list of SciKeras specific parameters. For details on other parameters,
1444-
please see the see the
1445-
[tf.keras.Model documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model).
1466+
please see the `tf.keras.Model documentation <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`_.
14461467
14471468
Parameters
14481469
----------

tests/test_deprecation.py

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
"""Tests for features scheduled for deprecation.
22
"""
3-
from unittest import mock
4-
5-
import numpy as np
63
import pytest
74

8-
from scikeras.wrappers import KerasClassifier, KerasRegressor
5+
from scikeras.wrappers import KerasClassifier
96

10-
from .mlp_models import dynamic_classifier, dynamic_regressor
7+
from .mlp_models import dynamic_classifier
118

129

1310
def test_build_fn_deprecation():
@@ -17,72 +14,3 @@ def test_build_fn_deprecation():
1714
clf = KerasClassifier(build_fn=dynamic_classifier, model__hidden_layer_sizes=(100,))
1815
with pytest.warns(UserWarning, match="``build_fn`` will be renamed to ``model``"):
1916
clf.fit([[0], [1]], [0, 1])
20-
21-
22-
@pytest.mark.parametrize(
23-
"wrapper,builder",
24-
[(KerasClassifier, dynamic_classifier), (KerasRegressor, dynamic_regressor),],
25-
)
26-
def test_kwarg_deprecation(wrapper, builder):
27-
"""Test that SciKeras supports the **kwarg interface in fit and predict
28-
but warns the user about deprecation of this interface.
29-
"""
30-
original_batch_size = 128
31-
kwarg_batch_size = 90
32-
kwarg_epochs = (
33-
2 # epochs is a special case for fit since SciKeras also uses it internally
34-
)
35-
extra_kwargs = {"workers": 1} # chosen because it is not a SciKeras hardcoded param
36-
est = wrapper(
37-
model=builder,
38-
model__hidden_layer_sizes=(100,),
39-
warm_start=True, # for mocking to work properly
40-
batch_size=original_batch_size, # test that this is overridden by kwargs
41-
fit__batch_size=original_batch_size, # test that this is overridden by kwargs
42-
predict__batch_size=original_batch_size, # test that this is overridden by kwargs
43-
)
44-
X, y = np.random.random((100, 10)), np.random.randint(low=0, high=3, size=(100,))
45-
est.initialize(X, y)
46-
match_txt = r"``\*\*kwargs`` has been deprecated in SciKeras"
47-
# check fit
48-
with pytest.warns(UserWarning, match=match_txt):
49-
with mock.patch.object(
50-
est.model_, "fit", side_effect=est.model_.fit
51-
) as mock_fit:
52-
est.fit(
53-
X, y, batch_size=kwarg_batch_size, epochs=kwarg_epochs, **extra_kwargs
54-
)
55-
call_args = mock_fit.call_args_list
56-
assert len(call_args) == 1
57-
call_kwargs = call_args[0][1]
58-
assert "batch_size" in call_kwargs
59-
assert call_kwargs["batch_size"] == kwarg_batch_size
60-
assert call_kwargs["epochs"] == kwarg_epochs
61-
assert len(est.history_["loss"]) == kwarg_epochs
62-
# check predict
63-
with pytest.warns(UserWarning, match=match_txt):
64-
with mock.patch.object(
65-
est.model_, "predict", side_effect=est.model_.predict
66-
) as mock_predict:
67-
est.predict(X, batch_size=kwarg_batch_size, **extra_kwargs)
68-
call_args = mock_predict.call_args_list
69-
assert len(call_args) == 1
70-
call_kwargs = call_args[0][1]
71-
assert "batch_size" in call_kwargs
72-
assert call_kwargs["batch_size"] == kwarg_batch_size
73-
if isinstance(est, KerasClassifier):
74-
est.predict_proba(X, batch_size=kwarg_batch_size, **extra_kwargs)
75-
call_args = mock_predict.call_args_list
76-
assert len(call_args) == 2
77-
call_kwargs = call_args[1][1]
78-
assert "batch_size" in call_kwargs
79-
assert call_kwargs["batch_size"] == kwarg_batch_size
80-
# check that params were restored and extra_kwargs were not stored
81-
for param_name in ("batch_size", "fit__batch_size", "predict__batch_size"):
82-
assert getattr(est, param_name) == original_batch_size
83-
for k in extra_kwargs.keys():
84-
assert (
85-
not hasattr(est, k)
86-
or hasattr(est, "fit__" + k)
87-
or hasattr(est, "predict__" + k)
88-
)

tests/test_parameters.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
from unittest import mock
4+
35
import numpy as np
46
import pytest
57

@@ -198,7 +200,7 @@ class TestMetricsParam:
198200
@pytest.mark.parametrize("metric", ("accuracy", "sparse_categorical_accuracy"))
199201
def test_metrics(self, metric):
200202
"""Test the metrics param.
201-
203+
202204
Specifically test ``accuracy``, which Keras automatically
203205
matches to the loss function and hence should be passed through
204206
as a string and not as a retrieved function.
@@ -252,3 +254,69 @@ def test_class_weight_param():
252254
clf.partial_fit(X_train, y_train)
253255
y_pred = clf.predict(X_test)
254256
assert np.mean(y_pred == 0) > 0.95
257+
258+
259+
@pytest.mark.parametrize(
260+
"wrapper,builder",
261+
[(KerasClassifier, dynamic_classifier), (KerasRegressor, dynamic_regressor)],
262+
)
263+
def test_kwargs(wrapper, builder):
264+
"""Test that SciKeras supports the **kwarg interface in fit and predict."""
265+
original_batch_size = 128
266+
kwarg_batch_size = 90
267+
kwarg_epochs = (
268+
2 # epochs is a special case for fit since SciKeras also uses it internally
269+
)
270+
extra_kwargs = {"workers": 1} # chosen because it is not a SciKeras hardcoded param
271+
est = wrapper(
272+
model=builder,
273+
model__hidden_layer_sizes=(100,),
274+
warm_start=True, # for mocking to work properly
275+
batch_size=original_batch_size, # test that this is overridden by kwargs
276+
fit__batch_size=original_batch_size, # test that this is overridden by kwargs
277+
predict__batch_size=original_batch_size, # test that this is overridden by kwargs
278+
)
279+
X, y = np.random.random((100, 10)), np.random.randint(low=0, high=3, size=(100,))
280+
est.initialize(X, y)
281+
# check fit
282+
match = "estimator parameters as keyword arguments"
283+
with mock.patch.object(est.model_, "fit", side_effect=est.model_.fit) as mock_fit:
284+
with pytest.warns(UserWarning, match=match.format("fit")):
285+
est.fit(
286+
X, y, batch_size=kwarg_batch_size, epochs=kwarg_epochs, **extra_kwargs
287+
)
288+
call_args = mock_fit.call_args_list
289+
assert len(call_args) == 1
290+
call_kwargs = call_args[0][1]
291+
assert "batch_size" in call_kwargs
292+
assert call_kwargs["batch_size"] == kwarg_batch_size
293+
assert call_kwargs["epochs"] == kwarg_epochs
294+
assert len(est.history_["loss"]) == kwarg_epochs
295+
# check predict
296+
with mock.patch.object(
297+
est.model_, "predict", side_effect=est.model_.predict
298+
) as mock_predict:
299+
with pytest.warns(UserWarning, match=match.format("predict")):
300+
est.predict(X, batch_size=kwarg_batch_size, **extra_kwargs)
301+
call_args = mock_predict.call_args_list
302+
assert len(call_args) == 1
303+
call_kwargs = call_args[0][1]
304+
assert "batch_size" in call_kwargs
305+
assert call_kwargs["batch_size"] == kwarg_batch_size
306+
if isinstance(est, KerasClassifier):
307+
with pytest.warns(UserWarning, match=match.format("predict")):
308+
est.predict_proba(X, batch_size=kwarg_batch_size, **extra_kwargs)
309+
call_args = mock_predict.call_args_list
310+
assert len(call_args) == 2
311+
call_kwargs = call_args[1][1]
312+
assert "batch_size" in call_kwargs
313+
assert call_kwargs["batch_size"] == kwarg_batch_size
314+
# check that params were restored and extra_kwargs were not stored
315+
for param_name in ("batch_size", "fit__batch_size", "predict__batch_size"):
316+
assert getattr(est, param_name) == original_batch_size
317+
for k in extra_kwargs.keys():
318+
assert (
319+
not hasattr(est, k)
320+
or hasattr(est, "fit__" + k)
321+
or hasattr(est, "predict__" + k)
322+
)

0 commit comments

Comments
 (0)