Skip to content

Commit 88dc474

Browse files
authored
Merge pull request #6 from vecxoz/fix5
Fix #5
2 parents 94881f2 + 39fae27 commit 88dc474

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

vecstack/core.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from sklearn.metrics import mean_absolute_error
5050
from sklearn.metrics import accuracy_score
5151
from sklearn.metrics import log_loss
52+
from sklearn.utils.validation import check_X_y, check_array
5253

5354
#-------------------------------------------------------------------------------
5455
#-------------------------------------------------------------------------------
@@ -404,11 +405,20 @@ def your_metric(y_true, y_pred):
404405
# If empty <models> list
405406
if 0 == len(models):
406407
raise ValueError('List of models is empty')
407-
# Convert arrays to ndarrays
408+
# Check arrays
408409
# y_train and sample_weight must be 1d ndarrays (i.e. row, not column)
409-
X_train = np.array(X_train)
410-
y_train = np.array(y_train).ravel()
411-
X_test = np.array(X_test)
410+
X_train, y_train = check_X_y(X_train,
411+
y_train,
412+
accept_sparse=True, # allow all types of sparse
413+
force_all_finite=False, # allow nan and inf because
414+
# some models (xgboost) can handle
415+
multi_output=False) # do not allow several columns in y_train
416+
417+
if X_test is not None: # allow X_test to be None for mode='oof'
418+
X_test = check_array(X_test,
419+
accept_sparse=True, # allow all types of sparse
420+
force_all_finite=False) # allow nan and inf because
421+
# some models (xgboost) can handle
412422
if sample_weight is not None:
413423
sample_weight = np.array(sample_weight).ravel()
414424
# <regression>

0 commit comments

Comments
 (0)