Skip to content

Commit 649a422

Browse files
committed
ravel Y
1 parent 69e33a1 commit 649a422

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

slise/slise.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ def fit(
268268
if len(X.shape) == 1:
269269
X.shape = X.shape + (1,)
270270
assert X.shape[0] == Y.shape[0], "X and Y must have the same number of items!"
271+
if len(Y.shape) > 1:
272+
Y = Y.ravel()
273+
assert X.shape[0] == Y.shape[0], "Y cannot have multiple columns!"
271274
self._X = X
272275
self._Y = Y
273276
if weight is None:
@@ -624,6 +627,9 @@ def __init__(
624627
if len(X.shape) == 1:
625628
X.shape = X.shape + (1,)
626629
assert X.shape[0] == Y.shape[0], "X and Y must have the same number of items"
630+
if len(Y.shape) > 1:
631+
Y = Y.ravel()
632+
assert X.shape[0] == Y.shape[0], "Y cannot have multiple columns!"
627633
self._logit = logit
628634
self._normalise = normalise
629635
self._X = X

0 commit comments

Comments
 (0)