Skip to content

Commit fb7f49e

Browse files
committed
Fix sklearn regressor test?
1 parent c074416 commit fb7f49e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

keras/wrappers/scikit_learn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,10 @@ def predict(self, x, **kwargs):
320320
Predictions.
321321
"""
322322
kwargs = self.filter_sk_params(Sequential.predict, kwargs)
323-
preds = self.model.predict(x, **kwargs)
324-
return np.squeeze(preds, axis=len(preds.shape) - 1)
323+
preds = np.array(self.model.predict(x, **kwargs))
324+
if preds.shape[-1] == 1:
325+
return np.squeeze(preds, axis=-1)
326+
return preds
325327

326328
def score(self, x, y, **kwargs):
327329
"""Returns the mean loss on the given test data and labels.

0 commit comments

Comments
 (0)