Skip to content

Commit 2a16835

Browse files
authored
feat: Generate nicer-looking SQL when grouping by expressions (#2536)
* feat: Generate nicer-looking SQL when grouping by expressions
* boolean-named columns
* add really mean test case
1 parent 6c110ca commit 2a16835

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

narwhals/_compliant/group_by.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,21 @@ def _parse_expr_keys(
122122
no overlap with any existing column name.
123123
- Add these temporary columns to the compliant dataframe.
124124
"""
125-
suffix_token = "_" * (max(len(str(c)) for c in compliant_frame.columns) + 1)
125+
tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1
126+
127+
def _temporary_name(key: object) -> str:
    """Return a temporary column name of the form `_<key>_tmp` plus padding.

    The name is right-padded with underscores up to `tmp_name_length`
    characters (taken from the enclosing scope). When `_<key>_tmp` is already
    at least that long, no padding is added, because multiplying a string by
    a non-positive count yields the empty string.

    Arguments:
        key: Original output name. It is accepted as any object because
            pandas allows non-string column names.

    Returns:
        The temporary name as a string.
    """
    # Normalise first: the length arithmetic below needs the string form.
    key_str = str(key)
    # 5 is the fixed overhead around the key: the leading `_` plus `_tmp`.
    padding = "_" * (tmp_name_length - len(key_str) - 5)
    return f"_{key_str}_tmp{padding}"
131+
126132
output_names = compliant_frame._evaluate_aliases(*keys)
127133

128134
safe_keys = [
129135
# multi-output expression cannot have duplicate names, hence it's safe to suffix
130-
key.name.suffix(suffix_token)
136+
key.name.map(_temporary_name)
131137
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
132138
# otherwise it's single named and we can use Expr.alias
133-
else key.alias(f"{new_name}{suffix_token}")
139+
else key.alias(_temporary_name(new_name))
134140
for key, new_name in zip(keys, output_names)
135141
]
136142
return (

tests/group_by_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -584,3 +584,10 @@ def test_group_by_selector(constructor: Constructor) -> None:
584584
)
585585
expected = {"a": [1, 1], "b": [4, 6], "c": [8.0, 9.0]}
586586
assert_equal_data(result, expected)
587+
588+
589+
def test_renaming_edge_case(constructor: Constructor) -> None:
    """Group by an expression key when a `_<key>_tmp`-style column already exists."""
    # The frame already holds `_a_tmp`, which has the same shape as the
    # temporary names generated for expression keys; it must survive the
    # group-by untouched and be aggregated like any other column.
    frame = nw.from_native(
        constructor({"a": [0, 0, 0], "_a_tmp": [1, 2, 3], "b": [4, 5, 6]})
    )
    grouped = frame.group_by(nw.col("a"))
    result = grouped.agg(nw.all().min())
    assert_equal_data(result, {"a": [0], "_a_tmp": [1], "b": [4]})

0 commit comments

Comments
 (0)