Skip to content

Commit 0929d52

Browse files
Riik and FBruzzesi authored
feat: support passing index object directly into maybe_set_index (#1319)
--------- Co-authored-by: FBruzzesi <[email protected]>
1 parent c694148 commit 0929d52

File tree

2 files changed

+167
-19
lines changed

2 files changed

+167
-19
lines changed

narwhals/utils.py

+67-11
Original file line numberDiff line numberDiff line change
@@ -274,15 +274,37 @@ def maybe_get_index(obj: T) -> Any | None:
274274
return None
275275

276276

277-
def maybe_set_index(df: T, column_names: str | list[str]) -> T:
277+
def maybe_set_index(
278+
obj: T,
279+
column_names: str | list[str] | None = None,
280+
*,
281+
index: Series | list[Series] | None = None,
282+
) -> T:
278283
"""
279-
Set columns `columns` to be the index of `df`, if `df` is pandas-like.
284+
Set the index of a DataFrame or a Series, if it's pandas-like.
285+
286+
Arguments:
287+
obj: object for which to maybe set the index (can be either a Narwhals `DataFrame`
288+
or `Series`).
289+
column_names: name or list of names of the columns to set as index.
290+
For dataframes, only one of `column_names` and `index` can be specified but
291+
not both. If `column_names` is passed and `obj` is a Series, then a
292+
`ValueError` is raised.
293+
index: series or list of series to set as index.
294+
295+
Raises:
296+
ValueError: If one of the following conditions happens:
297+
298+
- neither `column_names` nor `index` is provided
299+
- both `column_names` and `index` are provided
300+
- `column_names` is provided and `obj` is a Series
280301
281302
Notes:
282-
This is only really intended for backwards-compatibility purposes,
283-
for example if your library already aligns indices for users.
303+
This is only really intended for backwards-compatibility purposes, for example if
304+
your library already aligns indices for users.
284305
If you're designing a new library, we highly encourage you to not
285306
rely on the Index.
307+
286308
For non-pandas-like inputs, this is a no-op.
287309
288310
Examples:
@@ -297,15 +319,49 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T:
297319
4 1
298320
5 2
299321
"""
300-
df_any = cast(Any, df)
301-
native_frame = to_native(df_any)
302-
if is_pandas_like_dataframe(native_frame):
322+
323+
df_any = cast(Any, obj)
324+
native_obj = to_native(df_any)
325+
326+
if column_names is not None and index is not None:
327+
msg = "Only one of `column_names` or `index` should be provided"
328+
raise ValueError(msg)
329+
330+
if not column_names and not index:
331+
msg = "Either `column_names` or `index` should be provided"
332+
raise ValueError(msg)
333+
334+
if index is not None:
335+
keys = (
336+
[to_native(idx, pass_through=True) for idx in index]
337+
if _is_iterable(index)
338+
else to_native(index, pass_through=True)
339+
)
340+
else:
341+
keys = column_names
342+
343+
if is_pandas_like_dataframe(native_obj):
303344
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
304-
df_any._compliant_frame._from_native_frame(
305-
native_frame.set_index(column_names)
306-
)
345+
df_any._compliant_frame._from_native_frame(native_obj.set_index(keys))
346+
)
347+
elif is_pandas_like_series(native_obj):
348+
if column_names:
349+
msg = "Cannot set index using column names on a Series"
350+
raise ValueError(msg)
351+
352+
if (
353+
df_any._compliant_series._implementation is Implementation.PANDAS
354+
and df_any._compliant_series._backend_version < (1,)
355+
): # pragma: no cover
356+
native_obj = native_obj.set_axis(keys, inplace=False)
357+
else:
358+
native_obj = native_obj.set_axis(keys)
359+
360+
return df_any._from_compliant_series( # type: ignore[no-any-return]
361+
df_any._compliant_series._from_native_series(native_obj)
307362
)
308-
return df_any # type: ignore[no-any-return]
363+
else:
364+
return df_any # type: ignore[no-any-return]
309365

310366

311367
def maybe_reset_index(obj: T) -> T:

tests/utils_test.py

+100-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import string
4+
from typing import TYPE_CHECKING
45

56
import hypothesis.strategies as st
67
import pandas as pd
@@ -15,6 +16,9 @@
1516
from tests.utils import PANDAS_VERSION
1617
from tests.utils import get_module_version_as_tuple
1718

19+
if TYPE_CHECKING:
20+
from narwhals.series import Series
21+
1822

1923
def test_maybe_align_index_pandas() -> None:
2024
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0]))
@@ -58,21 +62,109 @@ def test_maybe_align_index_polars() -> None:
5862
nw.maybe_align_index(df, s[1:])
5963

6064

61-
def test_maybe_set_index_pandas() -> None:
62-
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[1, 2, 0]))
63-
result = nw.maybe_set_index(df, "b")
64-
expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[1, 2, 0]).set_index(
65-
"b"
66-
)
65+
@pytest.mark.parametrize(
66+
"column_names",
67+
["b", ["a", "b"]],
68+
)
69+
def test_maybe_set_index_pandas_column_names(
70+
column_names: str | list[str] | None,
71+
) -> None:
72+
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
73+
result = nw.maybe_set_index(df, column_names)
74+
expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).set_index(column_names)
6775
assert_frame_equal(nw.to_native(result), expected)
6876

6977

70-
def test_maybe_set_index_polars() -> None:
78+
@pytest.mark.parametrize(
79+
"column_names",
80+
[
81+
"b",
82+
["a", "b"],
83+
],
84+
)
85+
def test_maybe_set_index_polars_column_names(
86+
column_names: str | list[str] | None,
87+
) -> None:
88+
df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
89+
result = nw.maybe_set_index(df, column_names)
90+
assert result is df
91+
92+
93+
@pytest.mark.parametrize(
94+
"native_df_or_series",
95+
[pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), pd.Series([0, 1, 2])],
96+
)
97+
@pytest.mark.parametrize(
98+
("narwhals_index", "pandas_index"),
99+
[
100+
(nw.from_native(pd.Series([1, 2, 0]), series_only=True), pd.Series([1, 2, 0])),
101+
(
102+
[
103+
nw.from_native(pd.Series([0, 1, 2]), series_only=True),
104+
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
105+
],
106+
[
107+
pd.Series([0, 1, 2]),
108+
pd.Series([1, 2, 0]),
109+
],
110+
),
111+
],
112+
)
113+
def test_maybe_set_index_pandas_direct_index(
114+
narwhals_index: Series | list[Series] | None,
115+
pandas_index: pd.Series | list[pd.Series] | None,
116+
native_df_or_series: pd.DataFrame | pd.Series,
117+
) -> None:
118+
df = nw.from_native(native_df_or_series, allow_series=True)
119+
result = nw.maybe_set_index(df, index=narwhals_index)
120+
if isinstance(native_df_or_series, pd.Series):
121+
native_df_or_series.index = pandas_index
122+
assert_series_equal(nw.to_native(result), native_df_or_series)
123+
else:
124+
expected = native_df_or_series.set_index(pandas_index)
125+
assert_frame_equal(nw.to_native(result), expected)
126+
127+
128+
@pytest.mark.parametrize(
129+
"index",
130+
[
131+
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
132+
[
133+
nw.from_native(pd.Series([0, 1, 2]), series_only=True),
134+
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
135+
],
136+
],
137+
)
138+
def test_maybe_set_index_polars_direct_index(
139+
index: Series | list[Series] | None,
140+
) -> None:
71141
df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
72-
result = nw.maybe_set_index(df, "b")
142+
result = nw.maybe_set_index(df, index=index)
73143
assert result is df
74144

75145

146+
def test_maybe_set_index_pandas_series_column_names() -> None:
147+
df = nw.from_native(pd.Series([0, 1, 2]), allow_series=True)
148+
with pytest.raises(
149+
ValueError, match="Cannot set index using column names on a Series"
150+
):
151+
nw.maybe_set_index(df, column_names=["a"])
152+
153+
154+
def test_maybe_set_index_pandas_either_index_or_column_names() -> None:
155+
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
156+
column_names = ["a", "b"]
157+
index = nw.from_native(pd.Series([0, 1, 2]), series_only=True)
158+
with pytest.raises(
159+
ValueError, match="Only one of `column_names` or `index` should be provided"
160+
):
161+
nw.maybe_set_index(df, column_names=column_names, index=index)
162+
with pytest.raises(
163+
ValueError, match="Either `column_names` or `index` should be provided"
164+
):
165+
nw.maybe_set_index(df)
166+
167+
76168
def test_maybe_get_index_pandas() -> None:
77169
pandas_df = pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0])
78170
result = nw.maybe_get_index(nw.from_native(pandas_df))

0 commit comments

Comments
 (0)