|
1 |
| -from typing import Tuple |
| 1 | +from __future__ import annotations |
2 | 2 |
|
3 |
| -import numpy as np |
| 3 | +from typing import List |
| 4 | + |
| 5 | +import narwhals as nw |
4 | 6 | import pandas as pd
|
5 | 7 | from scipy.sparse import issparse
|
6 | 8 | from sklearn.utils import check_array
|
7 | 9 | from sklearn.utils.validation import _ensure_no_complex_data
|
8 | 10 |
|
9 | 11 |
|
10 |
| -def _split_groups_and_values( |
11 |
| - X, groups, name="", min_value_cols=1, check_X=True, **kwargs |
12 |
| -) -> Tuple[pd.DataFrame, np.ndarray]: |
13 |
| - _data_format_checks(X, name=name) |
14 |
| - check_array(X, ensure_min_features=min_value_cols, dtype=None, force_all_finite=False) |
| 12 | +def parse_X_y(X, y, groups, check_X=True, **kwargs) -> nw.DataFrame: |
| 13 | + """Converts X, y to narwhals dataframe. |
15 | 14 |
|
16 |
| - try: |
17 |
| - if isinstance(X, pd.DataFrame): |
18 |
| - X_group = X.loc[:, groups] |
19 |
| - X_value = X.drop(columns=groups).values |
20 |
| - else: |
21 |
| - X = np.asarray(X) # deals with `_NotAnArray` case |
22 |
| - X_group = pd.DataFrame(X[:, groups]) |
23 |
| - pos_indexes = range(X.shape[1]) |
24 |
| - X_value = np.delete(X, [pos_indexes[g] for g in groups], axis=1) |
25 |
| - except (KeyError, IndexError): |
26 |
| - raise ValueError(f"Could not drop groups {groups} from columns of X") |
| 15 | + If it is not a supported dataframe, it uses pandas constructor as a fallback. |
27 | 16 |
|
28 |
| - X_group = _check_grouping_columns(X_group, **kwargs) |
| 17 | + Additionally, data checks are performed. |
| 18 | + """ |
| 19 | + # Check raw X |
| 20 | + _data_format_checks(X) |
29 | 21 |
|
30 |
| - if check_X: |
31 |
| - X_value = check_array(X_value, **kwargs) |
| 22 | + # Convert X to Narwhals frame |
| 23 | + X = nw.from_native(X, strict=False, eager_only=True) |
| 24 | + if not isinstance(X, nw.DataFrame): |
| 25 | + X = nw.from_native(pd.DataFrame(X)) |
32 | 26 |
|
33 |
| - return X_group, X_value |
| 27 | + # Check groups and feaures values |
| 28 | + if groups is not None: |
| 29 | + _validate_groups_values(X, groups) |
34 | 30 |
|
| 31 | + if check_X: |
| 32 | + check_array(X.drop(groups), **kwargs) |
35 | 33 |
|
36 |
| -def _data_format_checks(X, name): |
37 |
| - _ensure_no_complex_data(X) |
| 34 | + # Convert y and assign it to the frame |
| 35 | + n_samples = X.shape[0] |
| 36 | + native_space = nw.get_native_namespace(X) |
| 37 | + |
| 38 | + y_native = native_space.Series([None] * n_samples) if y is None else native_space.Series(y) |
| 39 | + return X.with_columns(__sklego_target__=nw.from_native(y_native, allow_series=True)) |
38 | 40 |
|
39 |
| - if issparse(X): # sklearn.validation._ensure_sparse_format to complicated |
40 |
| - raise ValueError(f"The estimator {name} does not work on sparse matrices") |
41 | 41 |
|
| 42 | +def _validate_groups_values(X: nw.DataFrame, groups: List[int] | List[str]) -> None: |
| 43 | + X_cols = X.columns |
| 44 | + unexisting_cols = [g for g in groups if g not in X_cols] |
42 | 45 |
|
43 |
| -def _check_grouping_columns(X_group, **kwargs) -> pd.DataFrame: |
44 |
| - """Do basic checks on grouping columns""" |
45 |
| - # Do regular checks on numeric columns |
46 |
| - X_group_num = X_group.select_dtypes(include="number") |
47 |
| - if X_group_num.shape[1]: |
48 |
| - check_array(X_group_num, **kwargs) |
| 46 | + if len(unexisting_cols): |
| 47 | + raise ValueError(f"The following groups are not available in X: {unexisting_cols}") |
49 | 48 |
|
50 |
| - # Only check missingness in object columns |
51 |
| - if X_group.select_dtypes(exclude="number").isnull().any(axis=None): |
52 |
| - raise ValueError("X has NaN values") |
| 49 | + if X.select(nw.col(groups).is_null().any()).to_numpy().squeeze().any(): |
| 50 | + raise ValueError("Groups values have NaN") |
53 | 51 |
|
54 |
| - # The grouping part we always want as a DataFrame with range index |
55 |
| - return X_group.reset_index(drop=True) |
| 52 | + |
| 53 | +def _data_format_checks(X): |
| 54 | + """Checks that X is not sparse nor has complex dtype""" |
| 55 | + _ensure_no_complex_data(X) |
| 56 | + |
| 57 | + if issparse(X): # sklearn.validation._ensure_sparse_format to complicated |
| 58 | + msg = "Estimator does not work on sparse matrices" |
| 59 | + raise ValueError(msg) |
0 commit comments