Skip to content

Commit 186f099

Browse files
authored
Rework handling of categories in S-Learner. (#105)
* Rework handling of categories in S-Learner. * Rename and add doc strings. * Fix test. * Move infer_native_namespace to _narwhals_utils module. * Add _narwhals_utils module. * Rename and adapt interface. * Remove dead code. * Test _narwhals_utils. * Test _np_to_dummies. * Test specific _append_treatment_ functions separately for easier navigation. * Fix indexing issue. * Get rid of branching. * Cast bool index to int. * Use constant instead of value. * Move _stringify_column_names to _narwhals_utils. * Add docstring. * Use schema at creation instead of rename after creation. * Add type hint for native_namespace parameter.
1 parent 2d87255 commit 186f099

File tree

6 files changed

+578
-142
lines changed

6 files changed

+578
-142
lines changed

metalearners/_narwhals_utils.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) QuantCo 2024-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from collections.abc import Sequence
5+
from types import ModuleType
6+
7+
import narwhals.stable.v1 as nw
8+
import numpy as np
9+
import pandas as pd
10+
import polars as pl
11+
from narwhals.dependencies import is_into_series
12+
13+
from metalearners._typing import Vector
14+
15+
16+
def nw_to_dummies(
    x: nw.Series, categories: Sequence, column_name: str, drop_first: bool = True
) -> nw.DataFrame:
    """Turn a vector into a matrix with dummies.

    This operation is also referred to as one-hot-encoding.

    ``x`` is expected to have values which can be cast to integer.

    One ``Int8`` indicator column is produced per relevant category; each
    column is named ``column_name`` suffixed with ``_<category>``.
    ``column_name`` is used to select the column of ``x.to_frame()``, so it
    is presumably expected to equal the name of ``x``.

    If ``drop_first`` is ``True``, the first entry of ``categories`` does not
    get an indicator column.

    Raises a ``ValueError`` if fewer than two categories are provided or if
    ``x`` contains a value which is not part of ``categories``.
    """
    if len(categories) < 2:
        raise ValueError(
            "categories to be used for nw_to_dummies must have at least two "
            "distinct values."
        )

    # Every observed value has to be covered by ``categories``. A proper-subset
    # test of ``categories`` against the observed values would miss inputs which
    # contain an unknown value while not containing all of the categories.
    if not set(x.unique()) <= set(categories):
        raise ValueError("We observed a value which isn't part of the categories.")

    relevant_categories = categories[1:] if drop_first else categories
    return x.to_frame().select(
        [
            (nw.col(column_name) == cat).cast(nw.Int8).name.suffix(f"_{cat}")
            for cat in relevant_categories
        ]
    )
41+
42+
43+
def vector_to_nw(x: Vector, native_namespace: ModuleType | None = None) -> nw.Series:
    """Load a vector into a narwhals ``Series``.

    Objects accepted by ``is_into_series`` are wrapped directly. A
    ``np.ndarray`` additionally requires ``native_namespace`` in order to be
    loaded into narwhals; a ``ValueError`` is raised if it is missing.

    Raises a ``TypeError`` if ``x`` is neither of the above.
    """
    if not isinstance(x, np.ndarray):
        if is_into_series(x):
            return nw.from_native(x, series_only=True, eager_only=True)
        raise TypeError(f"Unexpected type {type(x)} for Vector.")

    if native_namespace is None:
        raise ValueError(
            "x is a numpy object but no native_namespace was provided to "
            "load it into narwhals."
        )

    # narwhals doesn't seem to like 1d numpy arrays. Therefore the array is first
    # turned into a single-column 2d array, loaded as a narwhals DataFrame and
    # then reduced to its only column.
    single_column_matrix = x.reshape(-1, 1)
    frame = nw.from_numpy(single_column_matrix, native_namespace=native_namespace)
    return frame["column_0"]
58+
59+
60+
def infer_native_namespace(df_nw: nw.DataFrame) -> ModuleType:
    """Return the module (``pandas`` or ``polars``) backing ``df_nw``.

    Raises a ``TypeError`` if the implementation of ``df_nw`` is neither
    pandas nor polars.
    """
    namespace_by_implementation = {"PANDAS": pd, "POLARS": pl}
    implementation_name = df_nw.implementation.name
    if implementation_name in namespace_by_implementation:
        return namespace_by_implementation[implementation_name]
    raise TypeError("Couldn't infer native_namespace of matrix.")
66+
67+
68+
def stringify_column_names(df_nw: nw.DataFrame) -> nw.DataFrame:
    """Rename every column of ``df_nw`` to the ``str`` of its current name."""
    renaming = {}
    for column in df_nw.columns:
        renaming[column] = str(column)
    return df_nw.rename(renaming)

metalearners/_utils.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import narwhals.stable.v1 as nw
1111
import numpy as np
1212
import pandas as pd
13+
import polars as pl
1314
import scipy
1415
from narwhals.dependencies import is_into_dataframe, is_into_series
1516
from sklearn.base import is_classifier, is_regressor
@@ -258,24 +259,32 @@ def check_probability(p: float, zero_included=False, one_included=False) -> None
258259
raise ValueError("Probability p must be less than or equal to 1.")
259260

260261

261-
def convert_treatment(treatment: Vector) -> np.ndarray:
262-
"""Convert to ``np.ndarray`` and adapt dtype, if necessary."""
263-
if isinstance(treatment, np.ndarray):
264-
new_treatment = treatment.copy()
265-
elif nw.dependencies.is_into_series(treatment):
266-
new_treatment = nw.from_native(
267-
treatment, series_only=True, eager_only=True
268-
).to_numpy() # type: ignore
269-
if new_treatment.dtype == bool:
270-
return new_treatment.astype(int)
271-
if new_treatment.dtype == float and all(x.is_integer() for x in new_treatment):
272-
return new_treatment.astype(int)
273-
274-
if not pd.api.types.is_integer_dtype(new_treatment):
262+
def adapt_treatment_dtypes(treatment: Vector) -> Vector:
    """Return ``treatment`` with an integer dtype, casting if necessary.

    Integer input is returned as is. Boolean input and float input consisting
    solely of integer values are cast to integer.

    Raises a ``TypeError`` for anything else.
    """
    error_message = "Treatment must be boolean, integer or float with integer values."

    if isinstance(treatment, pl.Series):
        polars_dtype = treatment.dtype
        if polars_dtype.is_integer():
            return treatment
        # Booleans and integer-valued floats are both cast the same way.
        if polars_dtype.to_python().__name__ == "bool" or (
            polars_dtype.is_float() and all(value.is_integer() for value in treatment)
        ):
            return treatment.cast(int)
        raise TypeError(error_message)

    # Remaining vector types expose a ``dtype`` attribute and ``astype``.
    if treatment.dtype == bool or (
        treatment.dtype == float and all(value.is_integer() for value in treatment)
    ):
        return treatment.astype(int)
    if pd.api.types.is_integer_dtype(treatment):
        return treatment
    raise TypeError(error_message)
279288

280289

281290
def supports_categoricals(model: _ScikitModel) -> bool:

0 commit comments

Comments
 (0)