Skip to content

Commit 12078a9

Browse files
skritsotalakisStelios KritsotalakisMarcoGorelli
authored
test: Added test which tests df.filter(nw.col(a).is_in(df_other[b])) (#1616)
--------- Co-authored-by: Stelios Kritsotalakis <kstelios@DESKTOP-D65QO0G> Co-authored-by: Marco Gorelli <[email protected]>
1 parent 4cbf646 commit 12078a9

File tree

3 files changed

+82
-73
lines changed

3 files changed

+82
-73
lines changed

narwhals/_arrow/group_by.py

Lines changed: 73 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -137,85 +137,85 @@ def agg_arrow(
137137
all_simple_aggs = False
138138
break
139139

140-
if all_simple_aggs:
141-
# Mapping from output name to
142-
# (aggregation_args, pyarrow_output_name) # noqa: ERA001
143-
simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {}
144-
for expr in exprs:
145-
if expr._depth == 0:
146-
# e.g. agg(nw.len()) # noqa: ERA001
147-
if (
148-
expr._output_names is None or expr._function_name != "len"
149-
): # pragma: no cover
150-
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
151-
raise AssertionError(msg)
152-
simple_aggregations[expr._output_names[0]] = (
153-
(keys[0], "count", pc.CountOptions(mode="all")),
154-
f"{keys[0]}_count",
155-
)
156-
continue
140+
if not all_simple_aggs:
141+
msg = (
142+
"Non-trivial complex aggregation found.\n\n"
143+
"Hint: you were probably trying to apply a non-elementary aggregation with a "
144+
"pyarrow table.\n"
145+
"Please rewrite your query such that group-by aggregations "
146+
"are elementary. For example, instead of:\n\n"
147+
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
148+
"use:\n\n"
149+
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
150+
)
151+
raise ValueError(msg)
157152

158-
# e.g. agg(nw.mean('a')) # noqa: ERA001
153+
# Mapping from output name to
154+
# (aggregation_args, pyarrow_output_name) # noqa: ERA001
155+
simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {}
156+
for expr in exprs:
157+
if expr._depth == 0:
158+
# e.g. agg(nw.len()) # noqa: ERA001
159159
if (
160-
expr._depth != 1 or expr._root_names is None or expr._output_names is None
160+
expr._output_names is None or expr._function_name != "len"
161161
): # pragma: no cover
162162
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
163163
raise AssertionError(msg)
164-
165-
function_name = remove_prefix(expr._function_name, "col->")
166-
function_name, option = polars_to_arrow_aggregations().get(
167-
function_name, (function_name, None)
164+
simple_aggregations[expr._output_names[0]] = (
165+
(keys[0], "count", pc.CountOptions(mode="all")),
166+
f"{keys[0]}_count",
168167
)
168+
continue
169169

170-
for root_name, output_name in zip(expr._root_names, expr._output_names):
171-
simple_aggregations[output_name] = (
172-
(root_name, function_name, option),
173-
f"{root_name}_{function_name}",
174-
)
175-
176-
aggs: list[Any] = []
177-
expected_pyarrow_column_names = keys.copy()
178-
new_column_names = keys.copy()
179-
for output_name, (
180-
aggregation_args,
181-
pyarrow_output_name,
182-
) in simple_aggregations.items():
183-
aggs.append(aggregation_args)
184-
expected_pyarrow_column_names.append(pyarrow_output_name)
185-
new_column_names.append(output_name)
186-
187-
result_simple = grouped.aggregate(aggs)
188-
189-
# Rename columns, being very careful
190-
expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
191-
for idx, item in enumerate(expected_pyarrow_column_names):
192-
expected_old_names_indices[item].append(idx)
193-
if not (
194-
set(result_simple.column_names) == set(expected_pyarrow_column_names)
195-
and len(result_simple.column_names) == len(expected_pyarrow_column_names)
170+
# e.g. agg(nw.mean('a')) # noqa: ERA001
171+
if (
172+
expr._depth != 1 or expr._root_names is None or expr._output_names is None
196173
): # pragma: no cover
197-
msg = (
198-
f"Safety assertion failed, expected {expected_pyarrow_column_names} "
199-
f"got {result_simple.column_names}, "
200-
"please report a bug at https://github.com/narwhals-dev/narwhals/issues"
201-
)
174+
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
202175
raise AssertionError(msg)
203-
index_map: list[int] = [
204-
expected_old_names_indices[item].pop(0) for item in result_simple.column_names
205-
]
206-
new_column_names = [new_column_names[i] for i in index_map]
207-
208-
result_simple = result_simple.rename_columns(new_column_names)
209-
return from_dataframe(result_simple)
210-
211-
msg = (
212-
"Non-trivial complex aggregation found.\n\n"
213-
"Hint: you were probably trying to apply a non-elementary aggregation with a "
214-
"pyarrow table.\n"
215-
"Please rewrite your query such that group-by aggregations "
216-
"are elementary. For example, instead of:\n\n"
217-
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
218-
"use:\n\n"
219-
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
220-
)
221-
raise ValueError(msg)
176+
177+
function_name = remove_prefix(expr._function_name, "col->")
178+
function_name, option = polars_to_arrow_aggregations().get(
179+
function_name, (function_name, None)
180+
)
181+
182+
for root_name, output_name in zip(expr._root_names, expr._output_names):
183+
simple_aggregations[output_name] = (
184+
(root_name, function_name, option),
185+
f"{root_name}_{function_name}",
186+
)
187+
188+
aggs: list[Any] = []
189+
expected_pyarrow_column_names = keys.copy()
190+
new_column_names = keys.copy()
191+
for output_name, (
192+
aggregation_args,
193+
pyarrow_output_name,
194+
) in simple_aggregations.items():
195+
aggs.append(aggregation_args)
196+
expected_pyarrow_column_names.append(pyarrow_output_name)
197+
new_column_names.append(output_name)
198+
199+
result_simple = grouped.aggregate(aggs)
200+
201+
# Rename columns, being very careful
202+
expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
203+
for idx, item in enumerate(expected_pyarrow_column_names):
204+
expected_old_names_indices[item].append(idx)
205+
if not (
206+
set(result_simple.column_names) == set(expected_pyarrow_column_names)
207+
and len(result_simple.column_names) == len(expected_pyarrow_column_names)
208+
): # pragma: no cover
209+
msg = (
210+
f"Safety assertion failed, expected {expected_pyarrow_column_names} "
211+
f"got {result_simple.column_names}, "
212+
"please report a bug at https://github.com/narwhals-dev/narwhals/issues"
213+
)
214+
raise AssertionError(msg)
215+
index_map: list[int] = [
216+
expected_old_names_indices[item].pop(0) for item in result_simple.column_names
217+
]
218+
new_column_names = [new_column_names[i] for i in index_map]
219+
220+
result_simple = result_simple.rename_columns(new_column_names)
221+
return from_dataframe(result_simple)

narwhals/expr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,6 +1783,7 @@ def is_in(self, other: Any) -> Self:
17831783
b: [[true,true,false,false]]
17841784
"""
17851785
if isinstance(other, Iterable) and not isinstance(other, (str, bytes)):
1786+
other = extract_compliant(self, other)
17861787
return self.__class__(lambda plx: self._to_compliant_expr(plx).is_in(other))
17871788
else:
17881789
msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead."

tests/expr_and_series/is_in_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,11 @@ def test_is_in_other(constructor: Constructor) -> None:
3535
),
3636
):
3737
nw.from_native(df_raw).with_columns(contains=nw.col("a").is_in("sets"))
38+
39+
40+
def test_filter_is_in_with_series(constructor_eager: ConstructorEager) -> None:
41+
data = {"a": [1, 4, 2, 5], "b": [1, 0, 2, 0]}
42+
df = nw.from_native(constructor_eager(data), eager_only=True)
43+
result = df.filter(nw.col("a").is_in(df["b"]))
44+
expected = {"a": [1, 2], "b": [1, 2]}
45+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)