|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import string
|
| 4 | +from typing import TYPE_CHECKING |
4 | 5 |
|
5 | 6 | import hypothesis.strategies as st
|
6 | 7 | import pandas as pd
|
|
15 | 16 | from tests.utils import PANDAS_VERSION
|
16 | 17 | from tests.utils import get_module_version_as_tuple
|
17 | 18 |
|
| 19 | +if TYPE_CHECKING: |
| 20 | + from narwhals.series import Series |
| 21 | + |
18 | 22 |
|
19 | 23 | def test_maybe_align_index_pandas() -> None:
|
20 | 24 | df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0]))
|
@@ -58,21 +62,109 @@ def test_maybe_align_index_polars() -> None:
|
58 | 62 | nw.maybe_align_index(df, s[1:])
|
59 | 63 |
|
60 | 64 |
|
61 |
| -def test_maybe_set_index_pandas() -> None: |
62 |
| - df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[1, 2, 0])) |
63 |
| - result = nw.maybe_set_index(df, "b") |
64 |
| - expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[1, 2, 0]).set_index( |
65 |
| - "b" |
66 |
| - ) |
@pytest.mark.parametrize("column_names", ["b", ["a", "b"]])
def test_maybe_set_index_pandas_column_names(
    column_names: str | list[str] | None,
) -> None:
    """Check that setting the index by column name(s) matches ``DataFrame.set_index``.

    Runs once with a single column name and once with a list of names.
    """
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    # Two independent native frames are built so the expectation never shares
    # state with the frame handed to narwhals.
    expected = pd.DataFrame(data).set_index(column_names)
    wrapped = nw.from_native(pd.DataFrame(data))
    result = nw.maybe_set_index(wrapped, column_names)
    assert_frame_equal(nw.to_native(result), expected)
|
68 | 76 |
|
69 | 77 |
|
70 |
| -def test_maybe_set_index_polars() -> None: |
@pytest.mark.parametrize("column_names", ["b", ["a", "b"]])
def test_maybe_set_index_polars_column_names(
    column_names: str | list[str] | None,
) -> None:
    """Check that ``maybe_set_index`` hands a polars-backed frame back untouched.

    Runs once with a single column name and once with a list of names.
    """
    frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    # Identity, not equality: the very same narwhals object must come back.
    assert nw.maybe_set_index(frame, column_names) is frame
| 91 | + |
| 92 | + |
@pytest.mark.parametrize(
    "native_df_or_series",
    [pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), pd.Series([0, 1, 2])],
)
@pytest.mark.parametrize(
    ("narwhals_index", "pandas_index"),
    [
        (nw.from_native(pd.Series([1, 2, 0]), series_only=True), pd.Series([1, 2, 0])),
        (
            [
                nw.from_native(pd.Series([0, 1, 2]), series_only=True),
                nw.from_native(pd.Series([1, 2, 0]), series_only=True),
            ],
            [
                pd.Series([0, 1, 2]),
                pd.Series([1, 2, 0]),
            ],
        ),
    ],
)
def test_maybe_set_index_pandas_direct_index(
    narwhals_index: Series | list[Series] | None,
    pandas_index: pd.Series | list[pd.Series] | None,
    native_df_or_series: pd.DataFrame | pd.Series,
) -> None:
    """Check ``maybe_set_index(..., index=...)`` against the native pandas result.

    Args:
        narwhals_index: A narwhals Series, or a list of them, passed as ``index``.
        pandas_index: The native pandas counterpart(s) used to build the
            expected object.
        native_df_or_series: The pandas DataFrame or Series under test.
    """
    df = nw.from_native(native_df_or_series, allow_series=True)
    result = nw.maybe_set_index(df, index=narwhals_index)
    if isinstance(native_df_or_series, pd.Series):
        # The parametrized Series object is shared across parameter
        # combinations (pytest passes values without copying), so the expected
        # index is assigned on a copy to avoid leaking state between runs.
        expected_series = native_df_or_series.copy()
        expected_series.index = pandas_index
        assert_series_equal(nw.to_native(result), expected_series)
    else:
        # ``set_index`` returns a new frame, so the shared parameter is untouched.
        expected_frame = native_df_or_series.set_index(pandas_index)
        assert_frame_equal(nw.to_native(result), expected_frame)
| 126 | + |
| 127 | + |
@pytest.mark.parametrize(
    "index",
    [
        nw.from_native(pd.Series([1, 2, 0]), series_only=True),
        [
            nw.from_native(pd.Series([0, 1, 2]), series_only=True),
            nw.from_native(pd.Series([1, 2, 0]), series_only=True),
        ],
    ],
)
def test_maybe_set_index_polars_direct_index(
    index: Series | list[Series] | None,
) -> None:
    """Check that passing ``index=`` for a polars-backed frame returns it untouched.

    Runs once with a single narwhals Series and once with a list of them.
    """
    frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    # Identity, not equality: the very same narwhals object must come back.
    assert nw.maybe_set_index(frame, index=index) is frame
|
74 | 144 |
|
75 | 145 |
|
def test_maybe_set_index_pandas_series_column_names() -> None:
    """Check that ``column_names`` on a pandas-backed Series raises ``ValueError``."""
    series = nw.from_native(pd.Series([0, 1, 2]), allow_series=True)
    expected_message = "Cannot set index using column names on a Series"
    with pytest.raises(ValueError, match=expected_message):
        nw.maybe_set_index(series, column_names=["a"])
| 152 | + |
| 153 | + |
def test_maybe_set_index_pandas_either_index_or_column_names() -> None:
    """Check that exactly one of ``column_names`` / ``index`` must be supplied.

    Passing both, or neither, is expected to raise ``ValueError`` with the
    corresponding message.
    """
    frame = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    names = ["a", "b"]
    direct_index = nw.from_native(pd.Series([0, 1, 2]), series_only=True)

    both_given = "Only one of `column_names` or `index` should be provided"
    neither_given = "Either `column_names` or `index` should be provided"

    # Both arguments supplied at once.
    with pytest.raises(ValueError, match=both_given):
        nw.maybe_set_index(frame, column_names=names, index=direct_index)

    # Neither argument supplied.
    with pytest.raises(ValueError, match=neither_given):
        nw.maybe_set_index(frame)
| 166 | + |
| 167 | + |
76 | 168 | def test_maybe_get_index_pandas() -> None:
|
77 | 169 | pandas_df = pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0])
|
78 | 170 | result = nw.maybe_get_index(nw.from_native(pandas_df))
|
|
0 commit comments