13
13
import polars as pl
14
14
import scipy
15
15
from narwhals .dependencies import is_into_dataframe , is_into_series
16
+ from scipy .sparse import csr_matrix
16
17
from sklearn .base import is_classifier , is_regressor
17
18
from sklearn .ensemble import (
18
19
HistGradientBoostingClassifier ,
@@ -42,12 +43,13 @@ def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
42
43
if not hasattr (rows , "to_numpy" ):
43
44
raise ValueError ("rows couldn't be converted to numpy." )
44
45
rows = rows .to_numpy ()
46
+ if isinstance (matrix , np .ndarray ) or isinstance (matrix , csr_matrix ):
47
+ return matrix [rows , :]
45
48
if is_into_dataframe (matrix ):
46
- matrix_nw = nw .from_native (matrix , eager_only = True ) # type: ignore
49
+ matrix_nw = nw .from_native (matrix , eager_only = True )
47
50
if rows .dtype == "bool" :
48
51
return matrix_nw .filter (rows .tolist ()).to_native ()
49
52
return matrix_nw [rows .tolist (), :].to_native ()
50
- return matrix [rows , :]
51
53
52
54
53
55
def index_vector (vector : Vector , rows : Vector ) -> Vector :
@@ -56,12 +58,14 @@ def index_vector(vector: Vector, rows: Vector) -> Vector:
56
58
if not hasattr (rows , "to_numpy" ):
57
59
raise ValueError ("rows couldn't be converted to numpy." )
58
60
rows = rows .to_numpy ()
61
+ if isinstance (vector , np .ndarray ):
62
+ return vector [rows ]
59
63
if is_into_series (vector ):
60
- vector_nw = nw .from_native (vector , series_only = True , eager_only = True ) # type: ignore
64
+ vector_nw = nw .from_native (vector , series_only = True , eager_only = True )
61
65
if rows .dtype == "bool" :
62
66
return vector_nw .filter (rows ).to_native ()
63
- return vector_nw [rows ].to_native () # type: ignore
64
- return vector [ rows ]
67
+ return vector_nw [rows ].to_native ()
68
+ raise TypeError ( f"Encountered unexpected type of vector: { type ( vector ) } ." )
65
69
66
70
67
71
def are_pd_indices_equal (* args : pd .DataFrame | pd .Series ) -> bool :
0 commit comments