Skip to content

Commit 494bde9

Browse files
authored
fix(udf): handle models serialization in complex types (#1459)
1 parent 011c0cf commit 494bde9

File tree

3 files changed

+91
-8
lines changed

3 files changed

+91
-8
lines changed

src/datachain/data_storage/warehouse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def _to_jsonable(self, obj: Any) -> Any:
8383
"""
8484

8585
if ModelStore.is_pydantic(type(obj)):
86-
return obj.model_dump()
86+
# Use Pydantic's JSON mode to ensure datetime and other non-JSON
87+
# native types are serialized in a compatible way.
88+
return obj.model_dump(mode="json")
8789

8890
if isinstance(obj, dict):
8991
out: dict[str, Any] = {}

src/datachain/lib/udf.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -560,13 +560,16 @@ def run(
560560
self.setup()
561561

562562
for batch in udf_inputs:
563-
udf_args = zip(
564-
*[
565-
self._prepare_row(row, udf_fields, catalog, cache, download_cb)
566-
for row in batch
567-
],
568-
strict=False,
569-
)
563+
prepared_rows = [
564+
self._prepare_row(row, udf_fields, catalog, cache, download_cb)
565+
for row in batch
566+
]
567+
batched_args = zip(*prepared_rows, strict=False)
568+
# Convert aggregated column values to lists. This keeps behavior
569+
# consistent with the type hints promoted in the public API.
570+
udf_args = [
571+
list(arg) if isinstance(arg, tuple) else arg for arg in batched_args
572+
]
570573
result_objs = self.process_safe(udf_args)
571574
udf_outputs = (self._flatten_row(row) for row in result_objs)
572575
output = (

tests/func/test_udf.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import posixpath
55
import sys
66
import time
7+
from collections.abc import Iterator
78

89
import multiprocess as mp
910
import pytest
@@ -1028,3 +1029,80 @@ def summarize(left_path, right_value):
10281029
for g in range(groups)
10291030
}
10301031
assert {row["partition"]: row["total"] for row in records} == expected_totals
1032+
1033+
1034+
def test_agg_list_file_and_map_count(tmp_dir, test_session):
    """Aggregate files into a list[File] column, then map a count over it.

    Verifies that a complex (list-of-model) UDF output can be consumed by a
    downstream `map` step and that the final count matches the input files.
    """
    filenames = [
        "hotdogs.txt",
        "dogs.txt",
        "dog.txt",
        "1dog.txt",
        "dogatxt.txt",
        "dog.txtx",
    ]
    for filename in filenames:
        (tmp_dir / filename).write_text(filename, encoding="utf-8")

    base_chain = dc.read_storage(tmp_dir.as_uri(), session=test_session).order_by(
        "file.path"
    )

    # Capture the File objects we expect the aggregation to collect.
    expected_files: list[File] = []
    for (obj,) in base_chain.select("file").to_iter():
        assert isinstance(obj, File)
        expected_files.append(obj)

    def collect_files(file: list[File]) -> Iterator[list[File]]:
        # Emit the whole partition's files as a single aggregated value.
        yield file

    def count_files(files: list[File]) -> int:
        return len(files)

    chain = base_chain.agg(files=collect_files).map(num_files=count_files)
    chain.save("temp_udf_types")

    # Validate result: one row whose count covers every input file.
    ds = dc.read_dataset("temp_udf_types", session=test_session)
    assert ds.select("num_files").to_list() == [(len(expected_files),)]
1073+
1074+
1075+
def test_agg_list_file_persist_and_read(tmp_dir, test_session):
    """Persist an aggregated list[File] column and read it back.

    The round-trip should yield real File model instances that match the
    original inputs field-for-field.
    """
    filenames = ["a.txt", "b.txt", "c.txt"]
    for filename in filenames:
        (tmp_dir / filename).write_text(filename, encoding="utf-8")

    base_chain = dc.read_storage(tmp_dir.as_uri(), session=test_session).order_by(
        "file.path"
    )

    # Record the File objects the aggregation is expected to preserve.
    expected_files: list[File] = []
    for (obj,) in base_chain.select("file").to_iter():
        assert isinstance(obj, File)
        expected_files.append(obj)

    def collect_files(file: list[File]) -> Iterator[list[File]]:
        yield file

    base_chain.agg(files=collect_files).save("temp_files_only")

    # When reading back, we should get a list of File objects
    ds = dc.read_dataset("temp_files_only", session=test_session)
    vals = ds.select("files").to_list()
    assert len(vals) == 1
    files_list = vals[0][0]
    assert isinstance(files_list, list)
    assert all(isinstance(f, File) for f in files_list)

    # Compare by serialized fields so ordering differences don't matter.
    expected_sorted: list[File] = sorted(expected_files, key=lambda f: f.path)
    actual_sorted: list[File] = sorted(files_list, key=lambda f: f.path)
    assert [f.model_dump() for f in actual_sorted] == [
        f.model_dump() for f in expected_sorted
    ]

0 commit comments

Comments
 (0)