
Commit fbb8e57

FBruzzesi, MarcoGorelli, anopsy, and DeaMariaLeon authored
feat: Narwhals for dataframe-agnostic codebase (#671)
* placeholder to develop narwhals features
* feat: make `ColumnDropper` dataframe-agnostic (#655)
* feat: make ColumnDropped dataframe-agnostic
* use narwhals[polars] in pyproject.toml, link to list of supported libraries
* note that narwhals is used for cross-dataframe support
* test refactor
* docstrings
---------
Co-authored-by: FBruzzesi <[email protected]>
* feat: make ColumnSelector dataframe-agnostic (#659)
* columnselector with test rufformatted
* adding whitespace
* fixed the fit and transform
* removed intendation in examples
* font:false
* feat: make `add_lags` dataframe-agnostic (#661)
* make add_lags dataframe-agnostic
* try getting tests to run?
* patch: cvxpy 1.5.0 support (#663)
---------
Co-authored-by: Francesco Bruzzesi <[email protected]>
* Make `RegressionOutlier` dataframe-agnostic (#665)
* make regression outlier df-agnostic
* need to use eager-only for this one
* pass native to check_array
* remove cudf, link to check_X_y
* feat: Make InformationFilter dataframe-agnostic
* Make Timegapsplit dataframe-agnostic (#668)
* make timegapsplit dataframe-agnostic
* actually, include cuDF
* feat: make FairClassifier data-agnostic (#669)
* start all over
* fixture working
* wip
* passing tests - again
* pre-commit complaining
* changed fixture on test_demographic_parity
* feat: Make PandasTypeSelector selector dataframe-agnostic (#670)
* make pandas dtype selector df-agnostic
* bump version
* 3.8 compat
* Update sklego/preprocessing/pandastransformers.py
Co-authored-by: Francesco Bruzzesi <[email protected]>
* fixup pyproject.toml
* unify (and test!) error message
* deprecate
* update readme
* undo contribution.md change
---------
Co-authored-by: Francesco Bruzzesi <[email protected]>
* format typeselector and bump version
* feat: Make grouped and hierarchical dataframe-agnostic (#667)
* feat: make grouped and hierarchical dataframe-agnostic
* add pyarrow
* narwhals grouped_transformer
* grouped transformer eureka
* hierarchical narwhalified
* so close but so far
* return series instead of DataFrame for y
* grouped WIP
* merge branch and fix grouped
* future annotations
* format
* handling negative indices
* solve conflicts
* hacking C
* fairness: change C values in tests
---------
Co-authored-by: Marco Edward Gorelli <[email protected]>
Co-authored-by: Magdalena Anopsy <[email protected]>
Co-authored-by: Dea María Léon <[email protected]>
1 parent 6a9654f commit fbb8e57

35 files changed: +1158 -736 lines changed

.github/workflows/test.yml

+1
@@ -4,6 +4,7 @@ on:
   pull_request:
     branches:
       - main
+      - narwhals-development

 jobs:
   test:

docs/api/preprocessing.md

+5
@@ -64,3 +64,8 @@
     options:
         show_root_full_path: true
         show_root_heading: true
+
+:::sklego.preprocessing.pandastransformers.TypeSelector
+    options:
+        show_root_full_path: true
+        show_root_heading: true

docs/contribution.md

+1-1
@@ -174,7 +174,7 @@ When a new feature is introduced, it should be documented, and typically there a
 - [x] A user guide in the `docs/user-guide/` folder.
 - [x] A python script in the `docs/_scripts/` folder to generate plots and code snippets (see [next section](#working-with-pymdown-snippets-extension))
 - [x] Relevant static files, such as images, plots, tables and html's, should be saved in the `docs/_static/` folder.
-- [x] Edit the `mkdocs.yaml` file to include the new pages in the navigation.
+- [x] Edit the `mkdocs.yaml` file to include the new pages in the navigation.

 ### Working with pymdown snippets extension

docs/this.md

+12-2
@@ -37,10 +37,20 @@ not everything needs to be built, not everything needs to be explored.
 Change everything and you'll soon be a jerk,
 you may invent a new tool, not a way to work.
 Some problems cannot be solved in a single day,
-but if you ignore them, they sometimes go away.
+but if you can ignore them, they sometimes go away.
+
+So as we forge ahead, let's remember the creed,
+simplicity over complexity, our library's seed.
+In the maze of features, let's not lose sight,
+of the end goal in mind shining bright.
+
+With each new feature, a temptation to craft,
+but elegance is found in what we choose to subtract.
+For every line of code, let's ask ourselves twice,
+does it add clarity or is it a vice?

 There's a lot of power in simplicity,
-it keeps you approach strong,
+it keeps the approach strong,
 if you understand the solution better than the problem,
 you're doing it wrong.
 ```

mkdocs.yaml

+1-3
@@ -21,9 +21,7 @@ theme:
   name: material
   logo: _static/logo.png
   favicon: _static/logo.png
-  font:
-    text: Ubuntu
-    code: Ubuntu Mono
+  font: false
   highlightjs: true
   hljs_languages:
     - bash

pyproject.toml

+4-2
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "scikit-lego"
-version = "0.8.2"
+version = "0.9.0"
 description="A collection of lego bricks for scikit-learn pipelines"

 license = {file = "LICENSE"}
@@ -20,6 +20,7 @@ maintainers = [
 ]

 dependencies = [
+    "narwhals>=0.8.13",
     "pandas>=1.1.5",
     "scikit-learn>=1.0",
     "importlib-metadata >= 1.0; python_version < '3.8'",
@@ -61,6 +62,8 @@ docs = [
 ]

 test = [
+    "narwhals[polars]",
+    "pyarrow",
     "pytest>=6.2.5",
     "pytest-xdist>=1.34.0",
     "pytest-cov>=2.6.1",
@@ -111,4 +114,3 @@ markers = [
     "formulaic: tests that require formulaic (deselect with '-m \"not formulaic\"')",
     "umap: tests that require umap (deselect with '-m \"not umap\"')"
 ]
-

readme.md

+1-1
@@ -120,7 +120,7 @@ Here's a list of features that this library currently offers:
 - `sklego.preprocessing.InformationFilter` transformer that can de-correlate features
 - `sklego.preprocessing.IdentityTransformer` returns the same data, allows for concatenating pipelines
 - `sklego.preprocessing.OrthogonalTransformer` makes all features linearly independent
-- `sklego.preprocessing.PandasTypeSelector` selects columns based on pandas type
+- `sklego.preprocessing.TypeSelector` selects columns based on type
 - `sklego.preprocessing.RandomAdder` adds randomness in training
 - `sklego.preprocessing.RepeatingBasisFunction` repeating feature engineering, useful for timeseries
 - `sklego.preprocessing.DictMapper` assign numeric values on categorical columns

sklego/common.py

+1-1
@@ -58,7 +58,7 @@ def transform_train(self, X, y=None):
         """

 _HASHERS = {
-    pd.DataFrame: lambda X: hashlib.sha256(pd.util.hash_pandas_object(X, index=True).values).hexdigest(),
+    pd.DataFrame: lambda X: hashlib.sha256(pd.util.hash_pandas_object(X, index=True).to_numpy()).hexdigest(),
     np.ndarray: lambda X: hash(X.data.tobytes()),
     np.memmap: lambda X: hash(X.data.tobytes()),
 }

sklego/datasets.py

+11-11
@@ -112,8 +112,8 @@ def load_penguins(return_X_y=False, as_frame=False):
                 "body_mass_g",
                 "sex",
             ]
-        ].values,
-        df["species"].values,
+        ].to_numpy(),
+        df["species"].to_numpy(),
     )
     if return_X_y:
         return X, y
@@ -162,8 +162,8 @@ def load_arrests(return_X_y=False, as_frame=False):
     if as_frame:
         return df
     X, y = (
-        df[["colour", "year", "age", "sex", "employed", "citizen", "checks"]].values,
-        df["released"].values,
+        df[["colour", "year", "age", "sex", "employed", "citizen", "checks"]].to_numpy(),
+        df["released"].to_numpy(),
     )
     if return_X_y:
         return X, y
@@ -208,7 +208,7 @@ def load_chicken(return_X_y=False, as_frame=False):
     df = pd.read_csv(filepath)
     if as_frame:
         return df
-    X, y = df[["time", "diet", "chick"]].values, df["weight"].values
+    X, y = df[["time", "diet", "chick"]].to_numpy(), df["weight"].to_numpy()
     if return_X_y:
         return X, y
     return {"data": X, "target": y}
@@ -265,8 +265,8 @@ def load_abalone(return_X_y=False, as_frame=False):
             "shell_weight",
             "rings",
         ]
-    ].values
-    y = df["sex"].values
+    ].to_numpy()
+    y = df["sex"].to_numpy()
     if return_X_y:
         return X, y
     return {"data": X, "target": y}
@@ -304,8 +304,8 @@ def load_heroes(return_X_y=False, as_frame=False):
     df = pd.read_csv(filepath)
     if as_frame:
         return df
-    X = df[["health", "attack"]].values
-    y = df["attack_type"].values
+    X = df[["health", "attack"]].to_numpy()
+    y = df["attack_type"].to_numpy()
     if return_X_y:
         return X, y
     return {"data": X, "target": y}
@@ -377,8 +377,8 @@ def load_hearts(return_X_y=False, as_frame=False):
             "ca",
             "thal",
         ]
-    ].values
-    y = df["target"].values
+    ].to_numpy()
+    y = df["target"].to_numpy()
     if return_X_y:
         return X, y
     return {"data": X, "target": y}

sklego/linear_model.py

+3-3
@@ -9,8 +9,8 @@
 from inspect import signature
 from warnings import warn

+import narwhals as nw
 import numpy as np
-import pandas as pd
 from scipy.optimize import minimize
 from scipy.special._ufuncs import expit
 from sklearn.base import BaseEstimator, RegressorMixin
@@ -493,8 +493,8 @@ def fit(self, X, y):
             raise ValueError(f"penalty should be either 'l1' or 'none', got {self.penalty}")

         self.sensitive_col_idx_ = self.sensitive_cols
-
-        if isinstance(X, pd.DataFrame):
+        X = nw.from_native(X, eager_only=True, strict=False)
+        if isinstance(X, nw.DataFrame):
             self.sensitive_col_idx_ = [i for i, name in enumerate(X.columns) if name in self.sensitive_cols]
         X, y = check_X_y(X, y, accept_large_sparse=False)
         sensitive = X[:, self.sensitive_col_idx_]
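
The `fit` change above turns `sensitive_cols` into positional indices when `X` is recognised as a dataframe, and otherwise uses the values as given. A minimal plain-Python sketch of that resolution step follows. The function `resolve_sensitive_idx` and its `column_names` argument are hypothetical names for illustration only; they are not part of sklego or Narwhals.

```python
def resolve_sensitive_idx(sensitive_cols, column_names=None):
    """Return positional indices for the sensitive columns.

    Hypothetical helper: `column_names` is None when the input is a plain
    array, in which case `sensitive_cols` is assumed to hold indices already.
    """
    if column_names is None:
        # Array input: nothing to resolve, use the indices as given.
        return list(sensitive_cols)
    wanted = set(sensitive_cols)
    # Dataframe input: keep the position of every column whose name matches.
    return [i for i, name in enumerate(column_names) if name in wanted]


idx_from_names = resolve_sensitive_idx(["sex", "age"], ["age", "income", "sex"])
idx_from_array = resolve_sensitive_idx([0, 2])
```

Note that the resolved indices follow the column order of the frame, not the order in which the sensitive columns were listed.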

sklego/meta/_grouped_utils.py

+41-37
@@ -1,55 +1,59 @@
-from typing import Tuple
+from __future__ import annotations

-import numpy as np
+from typing import List
+
+import narwhals as nw
 import pandas as pd
 from scipy.sparse import issparse
 from sklearn.utils import check_array
 from sklearn.utils.validation import _ensure_no_complex_data


-def _split_groups_and_values(
-    X, groups, name="", min_value_cols=1, check_X=True, **kwargs
-) -> Tuple[pd.DataFrame, np.ndarray]:
-    _data_format_checks(X, name=name)
-    check_array(X, ensure_min_features=min_value_cols, dtype=None, force_all_finite=False)
+def parse_X_y(X, y, groups, check_X=True, **kwargs) -> nw.DataFrame:
+    """Converts X, y to narwhals dataframe.

-    try:
-        if isinstance(X, pd.DataFrame):
-            X_group = X.loc[:, groups]
-            X_value = X.drop(columns=groups).values
-        else:
-            X = np.asarray(X)  # deals with `_NotAnArray` case
-            X_group = pd.DataFrame(X[:, groups])
-            pos_indexes = range(X.shape[1])
-            X_value = np.delete(X, [pos_indexes[g] for g in groups], axis=1)
-    except (KeyError, IndexError):
-        raise ValueError(f"Could not drop groups {groups} from columns of X")
+    If it is not a supported dataframe, it uses pandas constructor as a fallback.

-    X_group = _check_grouping_columns(X_group, **kwargs)
+    Additionally, data checks are performed.
+    """
+    # Check raw X
+    _data_format_checks(X)

-    if check_X:
-        X_value = check_array(X_value, **kwargs)
+    # Convert X to Narwhals frame
+    X = nw.from_native(X, strict=False, eager_only=True)
+    if not isinstance(X, nw.DataFrame):
+        X = nw.from_native(pd.DataFrame(X))

-    return X_group, X_value
+    # Check groups and feaures values
+    if groups is not None:
+        _validate_groups_values(X, groups)

+        if check_X:
+            check_array(X.drop(groups), **kwargs)

-def _data_format_checks(X, name):
-    _ensure_no_complex_data(X)
+    # Convert y and assign it to the frame
+    n_samples = X.shape[0]
+    native_space = nw.get_native_namespace(X)
+
+    y_native = native_space.Series([None] * n_samples) if y is None else native_space.Series(y)
+    return X.with_columns(__sklego_target__=nw.from_native(y_native, allow_series=True))

-    if issparse(X):  # sklearn.validation._ensure_sparse_format to complicated
-        raise ValueError(f"The estimator {name} does not work on sparse matrices")

+def _validate_groups_values(X: nw.DataFrame, groups: List[int] | List[str]) -> None:
+    X_cols = X.columns
+    unexisting_cols = [g for g in groups if g not in X_cols]

-def _check_grouping_columns(X_group, **kwargs) -> pd.DataFrame:
-    """Do basic checks on grouping columns"""
-    # Do regular checks on numeric columns
-    X_group_num = X_group.select_dtypes(include="number")
-    if X_group_num.shape[1]:
-        check_array(X_group_num, **kwargs)
+    if len(unexisting_cols):
+        raise ValueError(f"The following groups are not available in X: {unexisting_cols}")

-    # Only check missingness in object columns
-    if X_group.select_dtypes(exclude="number").isnull().any(axis=None):
-        raise ValueError("X has NaN values")
+    if X.select(nw.col(groups).is_null().any()).to_numpy().squeeze().any():
+        raise ValueError("Groups values have NaN")

-    # The grouping part we always want as a DataFrame with range index
-    return X_group.reset_index(drop=True)
+
+def _data_format_checks(X):
+    """Checks that X is not sparse nor has complex dtype"""
+    _ensure_no_complex_data(X)
+
+    if issparse(X):  # sklearn.validation._ensure_sparse_format to complicated
+        msg = "Estimator does not work on sparse matrices"
+        raise ValueError(msg)
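
The new `_validate_groups_values` performs two checks in order: every requested group column must exist, and none of the group columns may contain nulls. The sketch below mirrors that order in plain Python, with no dataframe library. It assumes rows are dictionaries and that `None` stands in for a null value; `validate_groups` and `error_message` are hypothetical helpers, not sklego functions.

```python
def validate_groups(rows, columns, groups):
    """Plain-Python analogue of the two checks (assumption: None means null)."""
    # 1. Every group column must be present.
    missing = [g for g in groups if g not in columns]
    if missing:
        raise ValueError(f"The following groups are not available in X: {missing}")
    # 2. No group column may hold a null. This is safe to index because
    #    step 1 already guaranteed the columns exist.
    if any(row[g] is None for row in rows for g in groups):
        raise ValueError("Groups values have NaN")


def error_message(groups, rows, columns):
    """Return the validation error text, or None when validation passes."""
    try:
        validate_groups(rows, columns, groups)
    except ValueError as exc:
        return str(exc)
    return None


columns = ["region", "store", "sales"]
rows = [
    {"region": "north", "store": "a", "sales": 1.0},
    {"region": None, "store": "b", "sales": 2.0},
]
missing_msg = error_message(["region", "country"], rows, columns)
null_msg = error_message(["region"], rows, columns)
ok_msg = error_message(["store"], rows, columns)
```

Running the existence check first means a misspelled group name is reported as missing rather than surfacing as a lookup failure during the null check.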

sklego/meta/_shrinkage_utils.py

+13-6
@@ -1,9 +1,10 @@
 from functools import partial

+import narwhals as nw
 import numpy as np
 from sklearn.utils.validation import check_is_fitted

-from sklego.common import expanding_list
+from sklego.common import as_list, expanding_list


 def constant_shrinkage(group_sizes, alpha: float) -> np.ndarray:
@@ -193,20 +194,26 @@ def _fit_shrinkage_factors(self, frame, groups, most_granular_only=False):
             Whether to return only the shrinkage factors for the most granular group values.
         """
         check_is_fitted(self, ["estimators_", "shrinkage_function_"])
-        counts = frame.groupby(groups).size().rename("counts")
+        counts = frame.group_by(groups).agg(nw.len().alias("counts"))
         all_grp_values = list(self.estimators_.keys())

         if most_granular_only:
-            all_grp_values = [grp_value for grp_value in all_grp_values if len(grp_value) == len(groups)]
+            all_grp_values = [grp_value for grp_value in all_grp_values if len(as_list(grp_value)) == len(groups)]

         hierarchical_counts = {
-            grp_value: [counts.loc[subgroup].sum() for subgroup in expanding_list(grp_value, tuple)]
+            grp_value: [
+                # As zip is "zip shortest" and filter works with comma separate conditions:
+                counts.filter(*[nw.col(c) == v for c, v in zip(groups, subgroup)])
+                .select(nw.sum("counts"))
+                .to_numpy()[0][0]
+                for subgroup in expanding_list(grp_value, tuple)
+            ]
             for grp_value in all_grp_values
         }

         shrinkage_factors = {
-            grp_value: self.shrinkage_function_(counts, **self.shrinkage_kwargs)
-            for grp_value, counts in hierarchical_counts.items()
+            grp_value: self.shrinkage_function_(counts_, **self.shrinkage_kwargs)
+            for grp_value, counts_ in hierarchical_counts.items()
         }

         # Normalize and pad
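
The "zip shortest" comment in the new `hierarchical_counts` comprehension relies on `zip(groups, subgroup)` stopping at the shorter input, so a prefix of length k only constrains the first k group columns. The plain-Python sketch below illustrates that idea without Narwhals. It assumes `expanding_list(grp_value, tuple)` yields the successive prefixes of `grp_value`; `expanding_prefixes` is a hypothetical stand-in for it, and the rows are made-up illustration data.

```python
from collections import Counter


def expanding_prefixes(values):
    """Hypothetical stand-in for expanding_list(..., tuple): all non-empty prefixes."""
    values = tuple(values)
    return [values[: i + 1] for i in range(len(values))]


def hierarchical_counts(rows, groups):
    """For each most granular group value, count rows at every hierarchy level."""
    leaf_values = Counter(tuple(row[g] for g in groups) for row in rows)
    result = {}
    for grp_value in leaf_values:
        per_level = []
        for subgroup in expanding_prefixes(grp_value):
            # zip stops at the shorter input, so a prefix of length k
            # produces exactly k (column, value) equality conditions.
            conditions = list(zip(groups, subgroup))
            per_level.append(sum(all(row[c] == v for c, v in conditions) for row in rows))
        result[grp_value] = per_level
    return result


rows = [
    {"country": "NL", "city": "Amsterdam"},
    {"country": "NL", "city": "Amsterdam"},
    {"country": "NL", "city": "Utrecht"},
    {"country": "IT", "city": "Rome"},
]
counts_by_group = hierarchical_counts(rows, ["country", "city"])
```

Each list runs from the coarsest level to the finest, which is the shape that is then passed to the shrinkage function in the diff above.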
