Skip to content

Commit 3ad4f5b

Browse files
committed
Include csr_matrix in check for matrix.
1 parent ba0ac26 commit 3ad4f5b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

metalearners/metalearner.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
import pandas as pd
1313
import shap
14+
from scipy.sparse import csr_matrix
1415
from sklearn.metrics import get_scorer
1516
from sklearn.model_selection import KFold
1617
from typing_extensions import Self
@@ -124,7 +125,7 @@ def _filter_x_columns(X: Matrix, feature_set: Features) -> Matrix:
124125
return X
125126
if len(feature_set) == 0:
126127
return np.ones((safe_len(X), 1))
127-
if isinstance(X, np.ndarray):
128+
if isinstance(X, np.ndarray) or isinstance(X, csr_matrix):
128129
return X[:, np.array(feature_set)]
129130
if nw.dependencies.is_into_dataframe(X):
130131
X_nw = nw.from_native(X, eager_only=True)

0 commit comments

Comments
 (0)