|
4 | 4 | import posixpath |
5 | 5 | import sys |
6 | 6 | import time |
| 7 | +from collections.abc import Iterator |
7 | 8 |
|
8 | 9 | import multiprocess as mp |
9 | 10 | import pytest |
@@ -1028,3 +1029,80 @@ def summarize(left_path, right_value): |
1028 | 1029 | for g in range(groups) |
1029 | 1030 | } |
1030 | 1031 | assert {row["partition"]: row["total"] for row in records} == expected_totals |
| 1032 | + |
| 1033 | + |
def test_agg_list_file_and_map_count(tmp_dir, test_session):
    """agg() collects File objects into a list column; map() then counts them.

    Verifies that a `list[File]` produced by an aggregate UDF can be fed
    straight into a mapper UDF that receives the materialized list.
    """
    file_names = [
        "hotdogs.txt",
        "dogs.txt",
        "dog.txt",
        "1dog.txt",
        "dogatxt.txt",
        "dog.txtx",
    ]
    for file_name in file_names:
        (tmp_dir / file_name).write_text(file_name, encoding="utf-8")

    chain = dc.read_storage(tmp_dir.as_uri(), session=test_session).order_by(
        "file.path"
    )

    # Materialize the expected File objects in path order.
    expected: list[File] = []
    for (entry,) in chain.select("file").to_iter():
        assert isinstance(entry, File)
        expected.append(entry)

    def gather_files(file: list[File]) -> Iterator[list[File]]:
        # Emit the whole partition's files as one list value.
        yield file

    def tally(files: list[File]) -> int:
        return len(files)

    chain.agg(files=gather_files).map(num_files=tally).save("temp_udf_types")

    # Validate result
    result = dc.read_dataset("temp_udf_types", session=test_session)
    assert result.select("num_files").to_list() == [(len(expected),)]
| 1073 | + |
| 1074 | + |
def test_agg_list_file_persist_and_read(tmp_dir, test_session):
    """A `list[File]` column written by agg() round-trips through a dataset.

    Saves the aggregated list, reads the dataset back, and checks the
    column deserializes to the same File objects.
    """
    for file_name in ("a.txt", "b.txt", "c.txt"):
        (tmp_dir / file_name).write_text(file_name, encoding="utf-8")

    chain = dc.read_storage(tmp_dir.as_uri(), session=test_session).order_by(
        "file.path"
    )

    # Collect the File objects we expect to see after the round trip.
    expected: list[File] = []
    for (entry,) in chain.select("file").to_iter():
        assert isinstance(entry, File)
        expected.append(entry)

    def gather_files(file: list[File]) -> Iterator[list[File]]:
        yield file

    chain.agg(files=gather_files).save("temp_files_only")

    # When reading back, we should get a list of File objects
    result = dc.read_dataset("temp_files_only", session=test_session)
    rows = result.select("files").to_list()
    assert len(rows) == 1
    round_tripped = rows[0][0]
    assert isinstance(round_tripped, list)
    assert all(isinstance(item, File) for item in round_tripped)

    # Compare order-independently via model dumps.
    by_path = lambda f: f.path  # noqa: E731
    assert [f.model_dump() for f in sorted(round_tripped, key=by_path)] == [
        f.model_dump() for f in sorted(expected, key=by_path)
    ]
0 commit comments