|
10 | 10 | import narwhals.stable.v1 as nw
|
11 | 11 | import numpy as np
|
12 | 12 | import pandas as pd
|
| 13 | +import polars as pl |
13 | 14 | import scipy
|
14 | 15 | from narwhals.dependencies import is_into_dataframe, is_into_series
|
15 | 16 | from sklearn.base import is_classifier, is_regressor
|
@@ -258,24 +259,32 @@ def check_probability(p: float, zero_included=False, one_included=False) -> None
|
258 | 259 | raise ValueError("Probability p must be less than or equal to 1.")
|
259 | 260 |
|
260 | 261 |
|
261 |
| -def convert_treatment(treatment: Vector) -> np.ndarray: |
262 |
| - """Convert to ``np.ndarray`` and adapt dtype, if necessary.""" |
263 |
| - if isinstance(treatment, np.ndarray): |
264 |
| - new_treatment = treatment.copy() |
265 |
| - elif nw.dependencies.is_into_series(treatment): |
266 |
| - new_treatment = nw.from_native( |
267 |
| - treatment, series_only=True, eager_only=True |
268 |
| - ).to_numpy() # type: ignore |
269 |
| - if new_treatment.dtype == bool: |
270 |
| - return new_treatment.astype(int) |
271 |
| - if new_treatment.dtype == float and all(x.is_integer() for x in new_treatment): |
272 |
| - return new_treatment.astype(int) |
273 |
| - |
274 |
| - if not pd.api.types.is_integer_dtype(new_treatment): |
| 262 | +def adapt_treatment_dtypes(treatment: Vector) -> Vector: |
| 263 | + """Cast the dtype of treatment to integer, if necessary. |
| 264 | +
|
| 265 | + Raises if not possible. |
| 266 | + """ |
| 267 | + if isinstance(treatment, pl.Series): |
| 268 | + dtype = treatment.dtype |
| 269 | + if dtype.is_integer(): |
| 270 | + return treatment |
| 271 | + if dtype.to_python().__name__ == "bool": |
| 272 | + return treatment.cast(int) |
| 273 | + if dtype.is_float() and all(x.is_integer() for x in treatment): |
| 274 | + return treatment.cast(int) |
| 275 | + raise TypeError( |
| 276 | + "Treatment must be boolean, integer or float with integer values." |
| 277 | + ) |
| 278 | + |
| 279 | + if treatment.dtype == bool: |
| 280 | + return treatment.astype(int) |
| 281 | + if treatment.dtype == float and all(x.is_integer() for x in treatment): |
| 282 | + return treatment.astype(int) |
| 283 | + if not pd.api.types.is_integer_dtype(treatment): |
275 | 284 | raise TypeError(
|
276 | 285 | "Treatment must be boolean, integer or float with integer values."
|
277 | 286 | )
|
278 |
| - return new_treatment |
| 287 | + return treatment |
279 | 288 |
|
280 | 289 |
|
281 | 290 | def supports_categoricals(model: _ScikitModel) -> bool:
|
|
0 commit comments