@@ -112,20 +112,21 @@ def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
112
112
matrix_nw = nw .from_native (matrix , eager_only = True )
113
113
114
114
if isinstance (rows , np .ndarray ):
115
- if rows .dtype == "bool" :
116
- return matrix_nw .filter (rows .tolist ()).to_native ()
117
- return matrix_nw [rows .tolist (), :].to_native ()
118
- if is_into_series (rows ):
115
+ rows_are_bool = rows .dtype == bool
116
+ rows_list = rows .tolist ()
117
+ elif is_into_series (rows ):
119
118
rows_nw = nw .from_native (rows , series_only = True , eager_only = True )
120
- if rows_nw .dtype == nw .Boolean :
121
- return matrix_nw .filter (rows_nw ).to_native ()
122
- return matrix_nw [rows_nw .to_list (), :].to_native ()
119
+ rows_are_bool = rows_nw .dtype == nw .Boolean
120
+ rows_list = rows_nw .to_list ()
121
+ else :
122
+ raise TypeError (f"Unexpected type { type (rows )} for rows." )
123
123
124
- raise ValueError (
125
- f"rows to index matrix with are of unexpected type: { type (rows )} "
126
- )
124
+ if rows_are_bool :
125
+ return matrix_nw .filter (rows_list ).to_native ()
126
+
127
+ return matrix_nw [rows_list , :].to_native ()
127
128
128
- raise ValueError (f"matrix to be indexed is of unexpected type: { type (matrix )} " )
129
+ raise TypeError (f"matrix to be indexed is of unexpected type: { type (matrix )} " )
129
130
130
131
131
132
def index_matrix3 (matrix : Matrix , rows : Vector ) -> Matrix :
0 commit comments