Skip to content

Commit c074416

Browse files
committed
Fix sklearn wrapper unit test in Python 3?
1 parent aa28910 commit c074416

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

keras/wrappers/scikit_learn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ def predict(self, x, **kwargs):
320320
Predictions.
321321
"""
322322
kwargs = self.filter_sk_params(Sequential.predict, kwargs)
323-
return np.squeeze(self.model.predict(x, **kwargs), axis=-1)
323+
preds = self.model.predict(x, **kwargs)
324+
return np.squeeze(preds, axis=len(preds.shape) - 1)
324325

325326
def score(self, x, y, **kwargs):
326327
"""Returns the mean loss on the given test data and labels.

0 commit comments

Comments
 (0)