Skip to content

Commit e95897f

Browse files
committed
Fix test.
1 parent 0e595fc commit e95897f

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

metalearners/_utils.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,21 @@ def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
112112
matrix_nw = nw.from_native(matrix, eager_only=True)
113113

114114
if isinstance(rows, np.ndarray):
115-
if rows.dtype == "bool":
116-
return matrix_nw.filter(rows.tolist()).to_native()
117-
return matrix_nw[rows.tolist(), :].to_native()
118-
if is_into_series(rows):
115+
rows_are_bool = rows.dtype == bool
116+
rows_list = rows.tolist()
117+
elif is_into_series(rows):
119118
rows_nw = nw.from_native(rows, series_only=True, eager_only=True)
120-
if rows_nw.dtype == nw.Boolean:
121-
return matrix_nw.filter(rows_nw).to_native()
122-
return matrix_nw[rows_nw.to_list(), :].to_native()
119+
rows_are_bool = rows_nw.dtype == nw.Boolean
120+
rows_list = rows_nw.to_list()
121+
else:
122+
raise TypeError(f"Unexpected type {type(rows)} for rows.")
123123

124-
raise ValueError(
125-
f"rows to index matrix with are of unexpected type: {type(rows)}"
126-
)
124+
if rows_are_bool:
125+
return matrix_nw.filter(rows_list).to_native()
126+
127+
return matrix_nw[rows_list, :].to_native()
127128

128-
raise ValueError(f"matrix to be indexed is of unexpected type: {type(matrix)}")
129+
raise TypeError(f"matrix to be indexed is of unexpected type: {type(matrix)}")
129130

130131

131132
def index_matrix3(matrix: Matrix, rows: Vector) -> Matrix:

0 commit comments

Comments
 (0)