Skip to content

Commit ba0ac26

Browse files
committed
Enable mypy.
1 parent 47a7eab commit ba0ac26

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

metalearners/_utils.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import polars as pl
1414
import scipy
1515
from narwhals.dependencies import is_into_dataframe, is_into_series
16+
from scipy.sparse import csr_matrix
1617
from sklearn.base import is_classifier, is_regressor
1718
from sklearn.ensemble import (
1819
HistGradientBoostingClassifier,
@@ -42,12 +43,13 @@ def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
4243
if not hasattr(rows, "to_numpy"):
4344
raise ValueError("rows couldn't be converted to numpy.")
4445
rows = rows.to_numpy()
46+
if isinstance(matrix, np.ndarray) or isinstance(matrix, csr_matrix):
47+
return matrix[rows, :]
4548
if is_into_dataframe(matrix):
46-
matrix_nw = nw.from_native(matrix, eager_only=True) # type: ignore
49+
matrix_nw = nw.from_native(matrix, eager_only=True)
4750
if rows.dtype == "bool":
4851
return matrix_nw.filter(rows.tolist()).to_native()
4952
return matrix_nw[rows.tolist(), :].to_native()
50-
return matrix[rows, :]
5153

5254

5355
def index_vector(vector: Vector, rows: Vector) -> Vector:
@@ -56,12 +58,14 @@ def index_vector(vector: Vector, rows: Vector) -> Vector:
5658
if not hasattr(rows, "to_numpy"):
5759
raise ValueError("rows couldn't be converted to numpy.")
5860
rows = rows.to_numpy()
61+
if isinstance(vector, np.ndarray):
62+
return vector[rows]
5963
if is_into_series(vector):
60-
vector_nw = nw.from_native(vector, series_only=True, eager_only=True) # type: ignore
64+
vector_nw = nw.from_native(vector, series_only=True, eager_only=True)
6165
if rows.dtype == "bool":
6266
return vector_nw.filter(rows).to_native()
63-
return vector_nw[rows].to_native() # type: ignore
64-
return vector[rows]
67+
return vector_nw[rows].to_native()
68+
raise TypeError(f"Encountered unexpected type of vector: {type(vector)}.")
6569

6670

6771
def are_pd_indices_equal(*args: pd.DataFrame | pd.Series) -> bool:

metalearners/metalearner.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,16 @@ def _filter_x_columns(X: Matrix, feature_set: Features) -> Matrix:
124124
return X
125125
if len(feature_set) == 0:
126126
return np.ones((safe_len(X), 1))
127+
if isinstance(X, np.ndarray):
128+
return X[:, np.array(feature_set)]
127129
if nw.dependencies.is_into_dataframe(X):
128-
X_nw = nw.from_native(X, eager_only=True) # type: ignore
130+
X_nw = nw.from_native(X, eager_only=True)
129131
if all(map(lambda x: isinstance(x, int), feature_set)):
130-
return X_nw.select(nw.nth(feature_set)).to_native() # type: ignore
132+
return X_nw.select(nw.nth(feature_set)).to_native()
131133
if all(map(lambda x: isinstance(x, str), feature_set)):
132-
return X_nw.select(feature_set).to_native() # type: ignore
134+
return X_nw.select(feature_set).to_native()
133135
raise ValueError("features must either be all ints or all strings.")
134-
return X[:, np.array(feature_set)]
136+
raise TypeError(f"Unexpected type of matrix: {type(X)}.")
135137

136138

137139
def _validate_n_folds_synchronize(n_folds: dict[str, int]) -> None:

0 commit comments

Comments
 (0)