diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst
index c771d394..624b6357 100644
--- a/docs/source/advanced.rst
+++ b/docs/source/advanced.rst
@@ -1,6 +1,6 @@
-===================================
-Advanced Usage of SciKeras Wrappers
-===================================
+==============
+Advanced Usage
+==============
 
 Wrapper Classes
 ---------------
@@ -128,6 +128,43 @@ offer an easy way to compile and tune compilation parameters. Examples:
 In all cases, returning an un-compiled model is equivalent to calling
 ``model.compile(**compile_kwargs)`` within ``model_build_fn``.
 
+.. _loss-selection:
+
+Loss selection
+++++++++++++++
+
+If you do not explicitly define a loss, SciKeras attempts to find a loss
+that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`).
+
+For guidance selecting losses in Keras, please see Jason Brownlee's
+excellent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_
+as well as the `Keras Losses docs`_.
+
+Default losses are selected as follows:
+
+Classification
+..............
+
++-----------+-----------+----------+---------------------------------+
+| # outputs | # classes | encoding | loss                            |
++===========+===========+==========+=================================+
+| 1         | <= 2      | any      | binary crossentropy             |
++-----------+-----------+----------+---------------------------------+
+| 1         | > 2       | labels   | sparse categorical crossentropy |
++-----------+-----------+----------+---------------------------------+
+| 1         | > 2       | one-hot  | unsupported                     |
++-----------+-----------+----------+---------------------------------+
+| > 1       | --        | --       | unsupported                     |
++-----------+-----------+----------+---------------------------------+
+
+Note that SciKeras will not automatically infer the loss for one-hot encoded
+targets; you must explicitly specify ``loss="categorical_crossentropy"``.
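+
+For example, with integer-encoded targets the default resolves to sparse
+categorical crossentropy. A minimal sketch (here ``get_clf`` stands in for
+any function returning an un-compiled, single-output Keras model):
+
+.. code-block:: python
+
+    import numpy as np
+    from scikeras.utils import loss_name
+    from scikeras.wrappers import KerasClassifier
+
+    X = np.random.uniform(size=(100, 8))
+    y = np.random.randint(0, 3, size=(100,))  # 3 classes, integer labels
+
+    clf = KerasClassifier(get_clf)  # loss defaults to "auto"
+    clf.fit(X, y)
+    assert loss_name(clf.model_.loss) == "sparse_categorical_crossentropy"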
+
+Regression
+..........
+
+Regression always defaults to mean squared error.
+For multi-output models, Keras will use the sum of each output's loss.
 
 Arguments to ``model_build_fn``
 -------------------------------
@@ -287,3 +324,7 @@ and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these sc
 
 .. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks
 
 .. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics
+
+.. _Keras Losses docs: https://www.tensorflow.org/api_docs/python/tf/keras/losses
+
+.. _How to Choose Loss Functions When Training Deep Learning Neural Networks: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/
\ No newline at end of file
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 8555ea29..66ba19b1 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -38,16 +38,25 @@ it on a toy classification dataset using SciKeras
         model.add(keras.layers.Activation("softmax"))
         return model
 
-    clf = KerasClassifier(
-        get_model,
-        loss="sparse_categorical_crossentropy",
-        hidden_layer_dim=100,
-    )
+    clf = KerasClassifier(get_model, hidden_layer_dim=100)
 
     clf.fit(X, y)
     y_proba = clf.predict_proba(X)
 
+Note that SciKeras even chooses a loss function and compiles your model.
+To override the default loss, simply specify a loss function:
+
+.. code-block:: diff
+
+   -KerasClassifier(get_model, hidden_layer_dim=100)
+   +KerasClassifier(get_model, loss="categorical_crossentropy", hidden_layer_dim=100)
+
+In this case, you would need to specify the loss since SciKeras
+will not default to categorical crossentropy, even for one-hot
+encoded targets.
+See :ref:`loss-selection` for more details.
+
 In an sklearn Pipeline
 ----------------------
diff --git a/scikeras/utils/__init__.py b/scikeras/utils/__init__.py
index 66650a0c..b8ddaee1 100644
--- a/scikeras/utils/__init__.py
+++ b/scikeras/utils/__init__.py
@@ -14,7 +14,7 @@ def _camel2snake(s: str) -> str:
     return "".join(["_" + c.lower() if c.isupper() else c for c in s]).lstrip("_")
 
 
-def loss_name(loss: Union[str, Loss, Callable]) -> str:
+def loss_name(loss: Union[str, Loss, Callable]) -> Union[str, None]:
     """Retrieves a loss's full name (eg: "mean_squared_error").
 
     Parameters
@@ -25,8 +25,9 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
 
     Returns
     -------
-    str
-        String name of the loss.
+    Union[str, None]
+        String name of the loss. String inputs that do not map to a known
+        Keras loss function return `None`.
 
     Notes
     -----
@@ -43,6 +44,8 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
     'binary_crossentropy'
     >>> loss_name(losses.binary_crossentropy)
     'binary_crossentropy'
+    >>> print(loss_name("abcdefg"))
+    None
 
     Raises
     ------
@@ -56,13 +59,17 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
         "``loss`` must be a string, a function, an instance of ``tf.keras.losses.Loss``"
         " or a type inheriting from ``tf.keras.losses.Loss``"
     )
-    fn_or_cls = keras_loss_get(loss)
+    try:
+        fn_or_cls = keras_loss_get(loss)
+    except ValueError:
+        # unknown loss
+        return None
     if isinstance(fn_or_cls, Loss):
         return _camel2snake(fn_or_cls.__class__.__name__)
     return fn_or_cls.__name__
 
 
-def metric_name(metric: Union[str, Metric, Callable]) -> str:
+def metric_name(metric: Union[str, Metric, Callable]) -> Union[str, None]:
     """Retrieves a metric's full name (eg: "mean_squared_error").
 
     Parameters
@@ -73,8 +80,9 @@ def metric_name(metric: Union[str, Metric, Callable]) -> str:
 
     Returns
     -------
-    str
+    Union[str, None]
         Full name for Keras metric. Ex: "mean_squared_error".
+        String inputs that do not map to a known Keras metric function return `None`.
 
     Notes
     -----
@@ -91,6 +99,8 @@ def metric_name(metric: Union[str, Metric, Callable]) -> str:
     'BinaryCrossentropy'
     >>> metric_name(metrics.binary_crossentropy)
     'binary_crossentropy'
+    >>> print(metric_name("abcdefg"))
+    None
 
     Raises
     ------
@@ -106,7 +116,11 @@ def metric_name(metric: Union[str, Metric, Callable]) -> str:
         " ``tf.keras.metrics.Metric`` or a type inheriting from"
         " ``tf.keras.metrics.Metric``"
     )
-    fn_or_cls = keras_metric_get(metric)
+    try:
+        fn_or_cls = keras_metric_get(metric)
+    except ValueError:
+        # unknown metric
+        return None
     if isinstance(fn_or_cls, Metric):
         return _camel2snake(fn_or_cls.__class__.__name__)
     return fn_or_cls.__name__
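The `loss_name`/`metric_name` change above is what makes `loss="auto"` workable
downstream: unknown identifiers now fall through instead of raising. A quick
sketch of the new fallback behavior (assuming this patch is applied):

```python
from scikeras.utils import loss_name, metric_name

# known identifiers still resolve to their canonical snake_case names
assert loss_name("BinaryCrossentropy") == "binary_crossentropy"
# unknown strings now return None instead of raising a ValueError
assert loss_name("not_a_real_loss") is None
assert metric_name("not_a_real_metric") is None
```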
diff --git a/scikeras/utils/transformers.py b/scikeras/utils/transformers.py
index 6016e394..c75376d1 100644
--- a/scikeras/utils/transformers.py
+++ b/scikeras/utils/transformers.py
@@ -154,7 +154,7 @@ def fit(self, y: np.ndarray) -> "ClassifierLabelEncoder":
             "multiclass-multioutput": FunctionTransformer(),
             "multilabel-indicator": FunctionTransformer(),
         }
-        if is_categorical_crossentropy(self.loss):
+        if target_type == "multiclass" and is_categorical_crossentropy(self.loss):
             encoders["multiclass"] = make_pipeline(
                 TargetReshaper(),
                 OneHotEncoder(
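The new `target_type == "multiclass"` guard means the one-hot pipeline is only
swapped in for true multiclass label targets. Targets that arrive already
one-hot encoded are typed as "multilabel-indicator" and pass through unchanged.
A rough illustration (a sketch assuming this patch, and assuming
`ClassifierLabelEncoder` accepts the `loss` keyword it uses in `fit`):

```python
import numpy as np

from scikeras.utils.transformers import ClassifierLabelEncoder

y_onehot = np.eye(3)[np.random.randint(0, 3, size=10)]  # already one-hot

enc = ClassifierLabelEncoder(loss="categorical_crossentropy").fit(y_onehot)
# "multilabel-indicator" targets take the FunctionTransformer branch,
# so they come back exactly as they went in
np.testing.assert_array_equal(enc.transform(y_onehot), y_onehot)
```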
diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py
index b36a5566..1adbeaea 100644
--- a/scikeras/wrappers.py
+++ b/scikeras/wrappers.py
@@ -345,27 +345,36 @@ def _get_compile_kwargs(self):
         compile_kwargs = route_params(
             init_params, destination="compile", pass_filter=self._compile_kwargs,
         )
-        compile_kwargs["optimizer"] = _class_from_strings(
-            compile_kwargs["optimizer"], optimizers_module.get
-        )
+        try:
+            compile_kwargs["optimizer"] = _class_from_strings(
+                compile_kwargs["optimizer"], optimizers_module.get
+            )
+        except ValueError:
+            pass  # unknown optimizer
         compile_kwargs["optimizer"] = unflatten_params(
             items=compile_kwargs["optimizer"],
             params=route_params(
                 init_params, destination="optimizer", pass_filter=set(), strict=True,
             ),
         )
-        compile_kwargs["loss"] = _class_from_strings(
-            compile_kwargs["loss"], losses_module.get
-        )
+        try:
+            compile_kwargs["loss"] = _class_from_strings(
+                compile_kwargs["loss"], losses_module.get
+            )
+        except ValueError:
+            pass  # unknown loss
         compile_kwargs["loss"] = unflatten_params(
             items=compile_kwargs["loss"],
             params=route_params(
                 init_params, destination="loss", pass_filter=set(), strict=False,
             ),
         )
-        compile_kwargs["metrics"] = _class_from_strings(
-            compile_kwargs["metrics"], metrics_module.get
-        )
+        try:
+            compile_kwargs["metrics"] = _class_from_strings(
+                compile_kwargs["metrics"], metrics_module.get
+            )
+        except ValueError:
+            pass  # unknown metric
         compile_kwargs["metrics"] = unflatten_params(
             items=compile_kwargs["metrics"],
             params=route_params(
@@ -374,7 +383,7 @@ def _get_compile_kwargs(self):
         )
         return compile_kwargs
 
-    def _build_keras_model(self):
+    def _build_keras_model(self) -> None:
         """Build the Keras model.
 
         This method will process all arguments and call the model building
@@ -417,13 +426,17 @@ def _build_keras_model(self):
             model = final_build_fn(**build_params)
         else:
             model = final_build_fn(**build_params)
-
-        return model
+        self.model_ = model
+        self._ensure_compiled_model()
+        return
 
     def _ensure_compiled_model(self) -> None:
         # compile model if user gave us an un-compiled model
         if not (hasattr(self.model_, "loss") and hasattr(self.model_, "optimizer")):
-            self.model_.compile(**self._get_compile_kwargs())
+            self._compile_model(self._get_compile_kwargs())
+
+    def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
+        self.model_.compile(**compile_kwargs)
 
     def _fit_keras_model(
         self,
@@ -514,13 +527,12 @@ def _fit_keras_model(
         self.history_ = defaultdict(list)
 
         for key, val in hist.history.items():
-            try:
-                key = metric_name(key)
-            except ValueError as e:
+            key_name = metric_name(key)
+            if key_name is not None:
                 # Keras puts keys like "val_accuracy" and "loss" and
-                # "val_loss" in hist.history
-                if "Unknown metric function" not in str(e):
-                    raise e
+                # "val_loss" in hist.history; these map to None
+                # since they are not real metric names
+                key = key_name
             self.history_[key] += val
 
     def _check_model_compatibility(self, y: np.ndarray) -> None:
@@ -790,7 +802,7 @@ def _initialize(
         feature_meta = getattr(self.feature_encoder, "get_metadata", dict)()
         vars(self).update(**feature_meta)
 
-        self.model_ = self._build_keras_model()
+        self._build_keras_model()
 
         return X, y
 
@@ -855,7 +867,6 @@ def _fit(
             X, y = self._initialize(X, y)
         else:
            X, y = self._validate_data(X, y)
-        self._ensure_compiled_model()
 
         if sample_weight is not None:
             X, sample_weight = self._validate_sample_weight(X, sample_weight)
@@ -1257,7 +1268,7 @@ def __init__(
         ] = "rmsprop",
         loss: Union[
             Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None
-        ] = None,
+        ] = "auto",
         metrics: Union[
             List[
                 Union[
@@ -1310,6 +1321,45 @@ def _type_of_target(self, y: np.ndarray) -> str:
             target_type = type_of_target(self.classes_)
         return target_type
 
+    def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
+        if compile_kwargs["loss"] == "auto":
+            if len(self.model_.outputs) > 1:
+                raise ValueError(
+                    'Only single-output models are supported with `loss="auto"`'
+                )
+            loss = None
+            hint = ""
+            if self.target_type_ == "binary":
+                if self.model_.outputs[0].shape[1] != 1:
+                    raise ValueError(
+                        "Binary classification expects a model with exactly 1 output unit."
+                    )
+                loss = "binary_crossentropy"
+            elif self.target_type_ == "multiclass":
+                if self.model_.outputs[0].shape[1] == 1:
+                    raise ValueError(
+                        "Multi-class targets require the model to have >1 output units."
+                    )
+                loss = "sparse_categorical_crossentropy"
+            elif self.target_type_ == "multilabel-indicator":
+                # one-hot encoded multiclass problem OR multilabel-indicator problem
+                hint = (
+                    "For this type of problem, the following may help:"
+                    '\n - If there is only one class per example, loss="categorical_crossentropy" might be appropriate.'
+                    '\n - If there are multiple classes per example, loss="binary_crossentropy" might be appropriate.'
+                )
+            if loss is None:
+                msg = (
+                    f'`loss="auto"` is not supported for tasks of type {self.target_type_}.'
+ "\nInstead, you must compile the model yourself or explicitly pass a loss function, for example:" + '\n clf = KerasClassifier(..., loss="categorical_crossentropy")' + ) + if hint: + msg += f"\n\n{hint}" + raise NotImplementedError(msg) + compile_kwargs["loss"] = loss + self.model_.compile(**compile_kwargs) + @staticmethod def scorer(y_true, y_pred, **kwargs) -> float: """Scoring function for KerasClassifier. @@ -1611,6 +1661,73 @@ class KerasRegressor(BaseWrapper): **BaseWrapper._tags, } + def __init__( + self, + model: Union[None, Callable[..., tf.keras.Model], tf.keras.Model] = None, + *, + build_fn: Union[ + None, Callable[..., tf.keras.Model], tf.keras.Model + ] = None, # for backwards compatibility + warm_start: bool = False, + random_state: Union[int, np.random.RandomState, None] = None, + optimizer: Union[ + str, tf.keras.optimizers.Optimizer, Type[tf.keras.optimizers.Optimizer] + ] = "rmsprop", + loss: Union[ + Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None + ] = "auto", + metrics: Union[ + List[ + Union[ + str, + tf.keras.metrics.Metric, + Type[tf.keras.metrics.Metric], + Callable, + ] + ], + None, + ] = None, + batch_size: Union[int, None] = None, + validation_batch_size: Union[int, None] = None, + verbose: int = 1, + callbacks: Union[ + List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]], + None, + ] = None, + validation_split: float = 0.0, + shuffle: bool = True, + run_eagerly: bool = False, + epochs: int = 1, + **kwargs, + ): + super().__init__( + model=model, + build_fn=build_fn, + warm_start=warm_start, + random_state=random_state, + optimizer=optimizer, + loss=loss, + metrics=metrics, + batch_size=batch_size, + validation_batch_size=validation_batch_size, + verbose=verbose, + callbacks=callbacks, + validation_split=validation_split, + shuffle=shuffle, + run_eagerly=run_eagerly, + epochs=epochs, + **kwargs, + ) + + def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: + if compile_kwargs["loss"] == "auto": + if len(self.model_.outputs) > 1: + raise ValueError( + 'Only single-output models are supported with `loss="auto"`' + ) + compile_kwargs["loss"] = "mean_squared_error" + self.model_.compile(**compile_kwargs) + @staticmethod def scorer(y_true, y_pred, **kwargs) -> float: """Scoring function for KerasRegressor. 
diff --git a/tests/mlp_models.py b/tests/mlp_models.py
index 4ffa845a..fece3dd7 100644
--- a/tests/mlp_models.py
+++ b/tests/mlp_models.py
@@ -26,11 +26,13 @@ def dynamic_classifier(
     for layer_size in hidden_layer_sizes:
         hidden = Dense(layer_size, activation="relu")(hidden)
 
+    loss = None if compile_kwargs["loss"] == "auto" else compile_kwargs["loss"]
+
     if target_type_ == "binary":
-        compile_kwargs["loss"] = compile_kwargs["loss"] or "binary_crossentropy"
+        compile_kwargs["loss"] = loss or "binary_crossentropy"
         out = [Dense(1, activation="sigmoid")(hidden)]
     elif target_type_ == "multilabel-indicator":
-        compile_kwargs["loss"] = compile_kwargs["loss"] or "binary_crossentropy"
+        compile_kwargs["loss"] = loss or "binary_crossentropy"
         if isinstance(n_classes_, list):
             out = [
                 Dense(1, activation="sigmoid")(hidden)
@@ -39,13 +41,11 @@ def dynamic_classifier(
         else:
             out = Dense(n_classes_, activation="softmax")(hidden)
     elif target_type_ == "multiclass-multioutput":
-        compile_kwargs["loss"] = compile_kwargs["loss"] or "binary_crossentropy"
+        compile_kwargs["loss"] = loss or "binary_crossentropy"
         out = [Dense(n, activation="softmax")(hidden) for n in n_classes_]
     else:  # multiclass
-        compile_kwargs["loss"] = (
-            compile_kwargs["loss"] or "sparse_categorical_crossentropy"
-        )
+        compile_kwargs["loss"] = loss or "sparse_categorical_crossentropy"
         out = [Dense(n_classes_, activation="softmax")(hidden)]
 
     model = Model(inp, out)
@@ -60,13 +60,13 @@ def dynamic_regressor(
     meta: Optional[Dict[str, Any]] = None,
     compile_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Model:
-    """Creates a basic MLP regressor dynamically.
-    """
+    """Creates a basic MLP regressor dynamically."""
     # get parameters
     n_features_in_ = meta["n_features_in_"]
     n_outputs_ = meta["n_outputs_"]
 
-    compile_kwargs["loss"] = compile_kwargs["loss"] or "mse"
+    if compile_kwargs["loss"] == "auto":
+        compile_kwargs["loss"] = "mean_squared_error"
 
     inp = Input(shape=(n_features_in_,))
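Build functions now receive the literal string "auto" in
`compile_kwargs["loss"]` whenever the user did not set a loss, so build
functions that compile the model themselves can branch on it the way
`dynamic_classifier` and `dynamic_regressor` do above. A condensed sketch of
that pattern (hypothetical `build_model`; assumes this patch):

```python
import tensorflow as tf

def build_model(meta, compile_kwargs):
    # fall back to our own default when SciKeras hands us loss="auto"
    loss = compile_kwargs["loss"]
    if loss == "auto":
        loss = "sparse_categorical_crossentropy"
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(meta["n_features_in_"],)),
        tf.keras.layers.Dense(meta["n_classes_"], activation="softmax"),
    ])
    model.compile(loss=loss, optimizer=compile_kwargs["optimizer"])
    return model
```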
- """ + """Test that history_'s keys are strings and values are lists.""" data = load_boston() X, y = data.data[:100], data.target[:100] estimator = KerasRegressor( @@ -756,8 +755,7 @@ def _keras_build_fn(self): class TestInitialize: - """Test the ``initialize`` method. - """ + """Test the ``initialize`` method.""" @pytest.mark.parametrize("wrapper", [KerasClassifier, KerasRegressor]) def test_prebuilt_model(self, wrapper): diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py new file mode 100644 index 00000000..35f12efe --- /dev/null +++ b/tests/test_loss_auto.py @@ -0,0 +1,233 @@ +import numpy as np +import pytest +import tensorflow as tf + +from sklearn.preprocessing import FunctionTransformer, OneHotEncoder + +from scikeras.utils import loss_name +from scikeras.wrappers import KerasClassifier, KerasRegressor + + +N_CLASSES = 4 +FEATURES = 8 +n_eg = 100 +X = np.random.uniform(size=(n_eg, FEATURES)).astype("float32") + + +def shallow_net(outputs=None, loss=None, compile=False): + model = tf.keras.Sequential() + model.add(tf.keras.layers.Input(shape=(FEATURES,))) + if outputs is None: + model.add(tf.keras.layers.Dense(N_CLASSES)) + else: + model.add(tf.keras.layers.Dense(outputs)) + + if compile: + model.compile(loss=loss) + + return model + + +@pytest.mark.parametrize( + "loss", + [ + "binary_crossentropy", + "categorical_crossentropy", + "sparse_categorical_crossentropy", + "poisson", + "kl_divergence", + "hinge", + "categorical_hinge", + "squared_hinge", + ], +) +def test_user_compiled(loss): + """Test to make sure that user compiled classification models work with all + classification losses. + """ + model__outputs = 1 if "binary" in loss else None + if loss == "binary_crosentropy": + y = np.random.randint(0, 2, size=(n_eg,)) + elif loss == "categorical_crossentropy": + # SciKeras does not auto one-hot encode unless + # loss="categorical_crossentropy" is explictily passed to the constructor + y = np.random.randint(0, N_CLASSES, size=(n_eg, 1)) + y = OneHotEncoder(sparse=False).fit_transform(y) + else: + y = np.random.randint(0, N_CLASSES, size=(n_eg,)) + est = KerasClassifier( + shallow_net, + model__compile=True, + model__loss=loss, + model__outputs=model__outputs, + ) + est.partial_fit(X, y) + + assert est.model_.loss == loss # not est.model_.loss.__name__ b/c user compiled + assert est.current_epoch == 1 + + +class NoEncoderClf(KerasClassifier): + """A classifier overriding default target encoding. + This simulates a user implementing custom encoding logic in + target_encoder to support multiclass-multioutput or + multilabel-indicator, which by default would raise an error. + """ + + @property + def target_encoder(self): + return FunctionTransformer() + + +@pytest.mark.parametrize( + "use_case,wrapper_cls", + [ + ("multilabel-indicator", NoEncoderClf), + ("multiclass-multioutput", NoEncoderClf), + ("classification_w_onehot_targets", KerasClassifier), + ], +) +def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls): + """Test for an appropriate error for tasks that are not supported + by `loss="auto"`. 
+ """ + extra = "" + fix_loss = None + if use_case == "multiclass-multioutput": + y1 = np.random.randint(0, 1, size=len(X)) + y2 = np.random.randint(0, 2, size=len(X)) + y = np.column_stack([y1, y2]) + elif use_case == "multilabel-indicator": + y1 = np.random.randint(0, 1, size=len(X)) + y = np.column_stack([y1, y1]) + y[0, :] = 1 + fix_loss = "binary_crossentropy" + extra = f'loss="{fix_loss}" might be appropriate' + elif use_case == "classification_w_onehot_targets": + y = np.random.choice(N_CLASSES, size=len(X)).astype(int) + y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) + fix_loss = "categorical_crossentropy" + extra = f'loss="{fix_loss}" might be appropriate' + match = '`loss="auto"` is not supported for tasks of type' + if extra: + match += f"(.|\n)+{extra}" + with pytest.raises(NotImplementedError, match=match): + wrapper_cls(shallow_net, model__compile=False).initialize(X, y) + if fix_loss: + wrapper_cls(shallow_net, model__compile=False, loss=fix_loss).initialize(X, y) + + +@pytest.mark.parametrize( + "use_case", + [ + "binary_classification", + "binary_classification_w_one_class", + "classification_w_1d_targets", + ], +) +def test_classifier_default_loss_only_model_specified(use_case): + """Test that KerasClassifier will auto-determine a loss function + when only the model is specified. + """ + + model__outputs = 1 if "binary" in use_case else None + if use_case == "binary_classification": + exp_loss = "binary_crossentropy" + y = np.random.choice(2, size=len(X)).astype(int) + elif use_case == "binary_classification_w_one_class": + exp_loss = "binary_crossentropy" + y = np.zeros(len(X)) + elif use_case == "classification_w_1d_targets": + exp_loss = "sparse_categorical_crossentropy" + y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) + + est = KerasClassifier(model=shallow_net, model__outputs=model__outputs) + + est.fit(X, y=y) + assert loss_name(est.model_.loss) == exp_loss + assert est.loss == "auto" + + +@pytest.mark.parametrize("use_case", ["single_output", "multi_output"]) +def test_regressor_default_loss_only_model_specified(use_case): + """Test that KerasRegressor will auto-determine a loss function + when only the model is specified. + """ + y = np.random.uniform(size=len(X)) + if use_case == "multi_output": + y = np.column_stack([y, y]) + est = KerasRegressor( + model=shallow_net, model__outputs=1 if "single" in use_case else 2 + ) + est.fit(X, y) + assert est.loss == "auto" + assert loss_name(est.model_.loss) == "mean_squared_error" + + +class Multi: + """Mixin for a simple 2 output model + """ + + def _keras_build_fn(self, compile): + inp = tf.keras.layers.Input(shape=(FEATURES,)) + out1 = tf.keras.layers.Dense(1)(inp) + out2 = tf.keras.layers.Dense(1)(inp) + model = tf.keras.Model(inp, [out1, out2]) + if compile: + model.compile(loss="mse") + return model + + @property + def target_encoder(self): + return FunctionTransformer(lambda x: [x[:, 0], x[:, 1]]) + + +class RegMulti(Multi, KerasRegressor): + pass + + +class ClfMulti(Multi, KerasClassifier): + pass + + +@pytest.mark.parametrize("user_compiled", [True, False]) +@pytest.mark.parametrize("est_cls", [RegMulti, ClfMulti]) +def test_multi_output_support(user_compiled, est_cls): + """Test that `loss="auto"` does not support SciKeras + compiling for multi-output models but allows user-compiled models. 
+ """ + y = np.random.randint(0, 1, size=len(X)) + y = np.column_stack([y, y]) + est = est_cls(model__compile=user_compiled) + if user_compiled: + est.fit(X, y) + else: + with pytest.raises( + ValueError, + match='Only single-output models are supported with `loss="auto"`', + ): + est.fit(X, y) + + +def test_multiclass_single_output_unit(): + """Test that multiclass targets requires > 1 output units. + """ + est = KerasClassifier(model=shallow_net, model__outputs=1) + y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) + with pytest.raises( + ValueError, + match="Multi-class targets require the model to have >1 output units", + ): + est.fit(X, y) + + +def test_binary_multiple_output_units(): + """Test that binary targets requires exactly 1 output unit. + """ + est = KerasClassifier(model=shallow_net, model__outputs=2) + y = np.random.choice(2, size=len(X)).astype(int) + with pytest.raises( + ValueError, + match="Binary classification expects a model with exactly 1 output unit", + ): + est.fit(X, y) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8fc919f1..7e63fd45 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -57,9 +57,8 @@ def test_loss_types(loss): loss_name(loss) -def test_unknown_loss_raises(): - with pytest.raises(ValueError, match="Unknown loss function"): - loss_name("unknown_loss") +def test_unknown_loss(): + assert loss_name("unknown_loss") == None @pytest.mark.parametrize("obj", [object(), object, list()]) @@ -69,8 +68,7 @@ def test_metric_types(obj): def test_unknown_metric(): - with pytest.raises(ValueError, match="Unknown metric function"): - metric_name("unknown_metric") + assert metric_name("unknown_metric") == None @pytest.mark.parametrize("metric", [CustomMetric, CustomMetric()])