diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 5ca51691..dbade3d6 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -503,7 +503,7 @@ def _fit_keras_model( # collect parameters params = self.get_params() fit_args = route_params(params, destination="fit", pass_filter=self._fit_kwargs) - fit_args["sample_weight"] = sample_weight + fit_args["sample_weight"] = [sample_weight] fit_args["epochs"] = initial_epoch + epochs fit_args["initial_epoch"] = initial_epoch fit_args.update(kwargs)