Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/datachain/lib/dc/datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,10 +1234,11 @@ def group_by( # noqa: C901, PLR0912
schema_partition_by.append(col)
else:
# BaseModel or other - add flattened columns directly
# Use underscores to flatten the column name
# to avoid MultiIndex in pandas output
for column in cast("list[Column]", columns):
col_type = self.signals_schema.get_column_type(column.name)
schema_fields[column.name] = col_type
schema_partition_by.append(col)
# Use the column's db name which already uses underscores
signal_columns.append(column)
else:
# simple signal - but we need to check if it's a complex signal
# complex signal - only include the columns used for partitioning
Expand Down
30 changes: 26 additions & 4 deletions tests/unit/lib/test_partition_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,16 +302,37 @@ class Level2(BaseModel):
total=dc.func.sum("amount"),
count=dc.func.count(),
partition_by="nested.level1.name", # This should work
).to_list("nested.level1.name", "total")
)
list_result = result.to_list("nested__level1__name", "total")

assert len(result) == 3 # Should have 3 unique names: test1, test2, test3
assert len(list_result) == 3 # Should have 3 unique names: test1, test2, test3

# Check the grouped results
name_to_total = dict(result)
name_to_total = dict(list_result)
assert name_to_total["test1"] == 40 # 10 + 30 (grouped by name)
assert name_to_total["test2"] == 20
assert name_to_total["test3"] == 40

from datachain.query.schema import Column as C

mutate_result = (
chain.mutate(name=C("nested.level1.name"))
.group_by(
total=dc.func.sum("amount"),
count=dc.func.count(),
partition_by="name", # This should work
)
.to_pandas()
)

assert mutate_result["total"].tolist() == [40, 20, 40]
assert mutate_result.columns.tolist() == ["name", "total", "count"]

pd_result = result.to_pandas()
assert len(pd_result) == 3
assert pd_result.columns.tolist() == ["nested__level1__name", "total", "count"]
assert pd_result["total"].tolist() == [40, 20, 40]


def test_nested_column_agg_partition_by(test_session):
class Person(BaseModel):
Expand Down Expand Up @@ -375,11 +396,12 @@ class Simple(BaseModel):
session=test_session,
)

# The nested column is flattened to use underscores in the output
result = chain.group_by(
total=dc.func.sum("amount"),
count=dc.func.count(),
partition_by="simple.name",
).to_list("simple.name", "total")
).to_list("simple__name", "total")

assert len(result) == 2 # Should have 2 unique names

Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_datachain_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_order_of_steps(mock_get_listing):


def test_all_possible_steps(test_session):
# Fix this test
persons_ds_name = "dev.my_pr.persons"
players_ds_name = "dev.my_pr.players"

Expand Down Expand Up @@ -188,14 +189,16 @@ def agg_persons(persons):
.offset(2)
.limit(5)
.group_by(age_avg=func.avg("persons.ages"), partition_by="persons.name")
.select("persons.name", "age_avg")
.select(
"persons__name", "age_avg"
) # After group_by, nested columns are flattened
.subtract(
players_chain,
on=["persons.name"],
on=["persons__name"],
right_on=["player.name"],
)
.hash()
) == "bd685bd97746a8e0e012c7029c7f2c8b17fc7eb5b7a5cd8fa5dacada57d75a07"
) == "73ef5dc642ff85377b47554aeb7458e05ac54a9efe835ad3a3a9525d00675a7b"


def test_diff(test_session):
Expand Down
Loading