1
1
"""Wrapper for using the Scikit-Learn API with Keras models.
2
2
"""
3
3
import inspect
4
- import os
5
4
import warnings
6
5
7
6
from collections import defaultdict
36
35
from scikeras .utils .transformers import ClassifierLabelEncoder , RegressorTargetEncoder
37
36
38
37
38
+ _kwarg_warn = """Passing estimator parameters as keyword arguments (aka as `**kwargs`) to `{0}` is not supported by the Scikit-Learn API, and will be removed in a future version of SciKeras.
39
+
40
+ To resolve this issue, either set these parameters in the constructor (e.g., `est = BaseWrapper(..., foo=bar)`) or via `set_params` (e.g., `est.set_params(foo=bar)`). The following parameters were passed to `{0}`:
41
+
42
+ {1}
43
+
44
+ More detail is available at https://www.adriangb.com/scikeras/migration.html#variable-keyword-arguments-in-fit-and-predict
45
+ """
46
+
47
+
39
48
class BaseWrapper (BaseEstimator ):
40
49
"""Implementation of the scikit-learn classifier API for Keras.
41
50
42
51
Below are a list of SciKeras specific parameters. For details on other parameters,
43
- please see the see the
44
- [tf.keras.Model documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model).
52
+ please see the see the `tf.keras.Model documentation <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`_.
45
53
46
54
Parameters
47
55
----------
@@ -696,20 +704,23 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper":
696
704
If not provided, then each sample is given unit weight.
697
705
**kwargs : Dict[str, Any]
698
706
Extra arguments to route to ``Model.fit``.
699
- This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
700
- These parameters can also be specified by prefixing `fit__` to a parameter at initialization;
701
- e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
707
+
708
+ Warnings
709
+ --------
710
+ Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``fit`` is not supported by the Scikit-Learn API,
711
+ and will be removed in a future version of SciKeras.
712
+ These parameters can also be specified by prefixing ``fit__`` to a parameter at initialization
713
+ (``BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)``)
714
+ or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
715
+
702
716
Returns
703
717
-------
704
718
BaseWrapper
705
- A reference to the instance that can be chain called
706
- (ex: instance.fit(X,y).transform(X) )
719
+ A reference to the instance that can be chain called (``est.fit(X,y).transform(X)``).
707
720
"""
708
- for k , v in kwargs .items ():
709
- warnings .warn (
710
- "``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
711
- f" Instead, set fit arguments at initialization (i.e., ``BaseWrapper({ k } ={ v } )``)"
712
- )
721
+ if kwargs :
722
+ kwarg_list = "\n * " .join ([f"`{ k } ={ v } `" for k , v in kwargs .items ()])
723
+ warnings .warn (_kwarg_warn .format ("fit" , kwarg_list ))
713
724
714
725
# epochs via kwargs > fit__epochs > epochs
715
726
kwargs ["epochs" ] = kwargs .get (
@@ -886,11 +897,9 @@ def _predict_raw(self, X, **kwargs):
886
897
For classification, this corresponds to predict_proba.
887
898
For regression, this corresponds to predict.
888
899
"""
889
- for k , v in kwargs .items ():
890
- warnings .warn (
891
- "``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
892
- f" Instead, set predict arguments at initialization (i.e., ``BaseWrapper({ k } ={ v } )``)"
893
- )
900
+ if kwargs :
901
+ kwarg_list = "\n * " .join ([f"`{ k } ={ v } `" for k , v in kwargs .items ()])
902
+ warnings .warn (_kwarg_warn .format ("predict" , kwarg_list ), stacklevel = 2 )
894
903
895
904
# check if fitted
896
905
if not self .initialized_ :
@@ -925,9 +934,14 @@ def predict(self, X, **kwargs):
925
934
and n_features is the number of features.
926
935
**kwargs : Dict[str, Any]
927
936
Extra arguments to route to ``Model.predict``.
928
- This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
929
- These parameters can also be specified by prefixing `predict__` to a parameter at initialization;
930
- e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
937
+
938
+ Warnings
939
+ --------
940
+ Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``predict`` is not supported by the Scikit-Learn API,
941
+ and will be removed in a future version of SciKeras.
942
+ These parameters can also be specified by prefixing ``predict__`` to a parameter at initialization
943
+ (``BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)``)
944
+ or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
931
945
932
946
Returns
933
947
-------
@@ -1090,8 +1104,7 @@ class KerasClassifier(BaseWrapper):
1090
1104
"""Implementation of the scikit-learn classifier API for Keras.
1091
1105
1092
1106
Below are a list of SciKeras specific parameters. For details on other parameters,
1093
- please see the see the
1094
- [tf.keras.Model documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model).
1107
+ please see the see the `tf.keras.Model documentation <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`_.
1095
1108
1096
1109
Parameters
1097
1110
----------
@@ -1351,15 +1364,19 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "KerasClassifier":
1351
1364
If not provided, then each sample is given unit weight.
1352
1365
**kwargs : Dict[str, Any]
1353
1366
Extra arguments to route to ``Model.fit``.
1354
- This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
1355
- These parameters can also be specified by prefixing `fit__` to a parameter at initialization;
1356
- e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
1367
+
1368
+ Warnings
1369
+ --------
1370
+ Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``fit`` is not supported by the Scikit-Learn API,
1371
+ and will be removed in a future version of SciKeras.
1372
+ These parameters can also be specified by prefixing ``fit__`` to a parameter at initialization
1373
+ (``KerasClassifier(..., fit__batch_size=32, predict__batch_size=1000)``)
1374
+ or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
1357
1375
1358
1376
Returns
1359
1377
-------
1360
1378
KerasClassifier
1361
- A reference to the instance that can be chain called
1362
- (ex: instance.fit(X,y).transform(X) )
1379
+ A reference to the instance that can be chain called (``est.fit(X,y).transform(X)``).
1363
1380
"""
1364
1381
self .classes_ = None
1365
1382
if self .class_weight is not None :
@@ -1415,9 +1432,14 @@ def predict_proba(self, X, **kwargs):
1415
1432
and n_features is the number of features.
1416
1433
**kwargs : Dict[str, Any]
1417
1434
Extra arguments to route to ``Model.predict``.
1418
- This functionality has been deprecated, and will be removed in SciKeras 1.0.0.
1419
- These parameters can also be specified by prefixing `predict__` to a parameter at initialization;
1420
- e.g, `BaseWrapper(..., fit__batch_size=32, predict__batch_size=1000)`.
1435
+
1436
+ Warnings
1437
+ --------
1438
+ Passing estimator parameters as keyword arguments (aka as ``**kwargs``) to ``predict_proba`` is not supported by the Scikit-Learn API,
1439
+ and will be removed in a future version of SciKeras.
1440
+ These parameters can also be specified by prefixing ``predict__`` to a parameter at initialization
1441
+ (``KerasClassifier(..., fit__batch_size=32, predict__batch_size=1000)``)
1442
+ or by using ``set_params`` (``est.set_params(fit__batch_size=32, predict__batch_size=1000)``).
1421
1443
1422
1444
Returns
1423
1445
-------
@@ -1441,8 +1463,7 @@ class KerasRegressor(BaseWrapper):
1441
1463
"""Implementation of the scikit-learn classifier API for Keras.
1442
1464
1443
1465
Below are a list of SciKeras specific parameters. For details on other parameters,
1444
- please see the see the
1445
- [tf.keras.Model documentation](https://www.tensorflow.org/api_docs/python/tf/keras/Model).
1466
+ please see the see the `tf.keras.Model documentation <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`_.
1446
1467
1447
1468
Parameters
1448
1469
----------
0 commit comments