Commit a4b1a13

Merge pull request #43 from MITLibraries/TIMX-371-dedupe-records
TIMX 371 - dedupe records
2 parents 059c576 + 875be3c commit a4b1a13

File tree: 5 files changed (+405, -71 lines)


abdiff/core/collate_ab_transforms.py

Lines changed: 156 additions & 35 deletions
@@ -1,3 +1,5 @@
+# ruff: noqa: TRY003
+
 import itertools
 import json
 import logging
@@ -8,6 +10,7 @@

 import duckdb
 import ijson
+import pandas as pd
 import pyarrow as pa

 from abdiff.core.exceptions import OutputValidationError
@@ -20,15 +23,21 @@
     (
         pa.field("timdex_record_id", pa.string()),
         pa.field("source", pa.string()),
+        pa.field("run_date", pa.date32()),
+        pa.field("run_type", pa.string()),
+        pa.field("action", pa.string()),
         pa.field("record", pa.binary()),
         pa.field("version", pa.string()),
         pa.field("transformed_file_name", pa.string()),
     )
 )
-JOINED_DATASET_SCHEMA = pa.schema(
+COLLATED_DATASET_SCHEMA = pa.schema(
     (
         pa.field("timdex_record_id", pa.string()),
         pa.field("source", pa.string()),
+        pa.field("run_date", pa.date32()),
+        pa.field("run_type", pa.string()),
+        pa.field("action", pa.string()),
         pa.field("record_a", pa.binary()),
         pa.field("record_b", pa.binary()),
     )
@@ -40,18 +49,19 @@ def collate_ab_transforms(
 ) -> str:
     """Collates A/B transformed files into a Parquet dataset.

-    This process can be summarized into two (2) important steps:
+    This process can be summarized into three (3) important steps:
     1. Write all transformed JSON records into a temporary Parquet dataset
        partitioned by the transformed file name.
     2. For every transformed file, use DuckDB to join A/B Parquet tables
        using the TIMDEX record ID and write joined records to a Parquet dataset.
-
-    This function (and its subfunctions) uses DuckDB, generators, and batching to
-    write records to Parquet datasets in a memory-efficient manner.
+    3. Dedupe joined records to ensure that only the most recent, non-"deleted"
+       timdex_record_id is present in the final output.
     """
     transformed_dataset_path = tempfile.TemporaryDirectory()
+    joined_dataset_path = tempfile.TemporaryDirectory()
     collated_dataset_path = str(Path(run_directory) / "collated")

+    # build temporary transformed dataset
     transformed_written_files = write_to_dataset(
         get_transformed_batches_iter(run_directory, ab_transformed_file_lists),
         schema=TRANSFORMED_DATASET_SCHEMA,
@@ -62,15 +72,30 @@ def collate_ab_transforms(
         f"Wrote {len(transformed_written_files)} parquet file(s) to transformed dataset"
     )

+    # build temporary collated dataset
     joined_written_files = write_to_dataset(
         get_joined_batches_iter(transformed_dataset_path.name),
-        base_dir=collated_dataset_path,
-        schema=JOINED_DATASET_SCHEMA,
+        base_dir=joined_dataset_path.name,
+        schema=COLLATED_DATASET_SCHEMA,
     )
     logger.info(f"Wrote {len(joined_written_files)} parquet file(s) to collated dataset")

+    # build final deduped and collated dataset
+    deduped_written_files = write_to_dataset(
+        get_deduped_batches_iter(joined_dataset_path.name),
+        base_dir=collated_dataset_path,
+        schema=COLLATED_DATASET_SCHEMA,
+    )
+    logger.info(
+        f"Wrote {len(deduped_written_files)} parquet file(s) to deduped collated dataset"
+    )
+
     validate_output(collated_dataset_path)

+    # ensure temporary artifacts removed
+    transformed_dataset_path.cleanup()
+    joined_dataset_path.cleanup()
+
     return collated_dataset_path

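The write_to_dataset() calls above stream generator output into Parquet datasets; that helper itself is not part of this diff. As a rough, hypothetical sketch of the underlying batching pattern (the function name, schema, and /tmp path below are illustrative, not the repo's implementation), a generator of dicts can be chunked into Arrow record batches and written to a dataset incrementally:

```python
import pyarrow as pa
import pyarrow.dataset as ds

SCHEMA = pa.schema([("timdex_record_id", pa.string()), ("source", pa.string())])


def record_batches(rows, schema, batch_size=1000):
    """Yield pyarrow.RecordBatch objects from an iterable of dicts."""
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            yield pa.RecordBatch.from_pylist(batch, schema=schema)
            batch = []
    if batch:
        yield pa.RecordBatch.from_pylist(batch, schema=schema)


rows = ({"timdex_record_id": f"alma:{i}", "source": "alma"} for i in range(10_000))

# write batches to a Parquet dataset without holding all rows in memory
ds.write_dataset(
    record_batches(rows, SCHEMA),
    base_dir="/tmp/example_dataset",
    format="parquet",
    schema=SCHEMA,
    existing_data_behavior="overwrite_or_ignore",
)
```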
@@ -85,22 +110,44 @@ def get_transformed_records_iter(

     * timdex_record_id: The TIMDEX record ID.
     * source: The shorthand name of the source as denoted in by Transmogrifier
-      (see https://github.com/MITLibraries/transmogrifier/blob/main/transmogrifier/config.py).
+    * run_date: Run date from TIMDEX ETL
+    * run_type: "full" or "daily"
+    * action: "index" or "delete"
     * record: The TIMDEX record serialized to a JSON string then encoded to bytes.
     * version: The version of the transform, parsed from the absolute filepath to a
       transformed file.
     * transformed_file_name: The name of the transformed file, excluding file extension.
     """
     version = get_transform_version(transformed_file)
     filename_details = parse_timdex_filename(transformed_file)
-    with open(transformed_file, "rb") as file:
-        for record in ijson.items(file, "item"):
+
+    base_record = {
+        "source": filename_details["source"],
+        "run_date": filename_details["run-date"],
+        "run_type": filename_details["run-type"],
+        "action": filename_details["action"],
+        "version": version,
+        "transformed_file_name": transformed_file.split("/")[-1],
+    }
+
+    # handle JSON files with records to index
+    if transformed_file.endswith(".json"):
+        with open(transformed_file, "rb") as file:
+            for record in ijson.items(file, "item"):
+                yield {
+                    **base_record,
+                    "timdex_record_id": record["timdex_record_id"],
+                    "record": json.dumps(record).encode(),
+                }
+
+    # handle TXT files with records to delete
+    else:
+        deleted_records_df = pd.read_csv(transformed_file, header=None)
+        for row in deleted_records_df.itertuples():
             yield {
-                "timdex_record_id": record["timdex_record_id"],
-                "source": filename_details["source"],
-                "record": json.dumps(record).encode(),
-                "version": version,
-                "transformed_file_name": transformed_file.split("/")[-1],
+                **base_record,
+                "timdex_record_id": row[1],
+                "record": None,
             }

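To illustrate the new deletions branch in get_transformed_records_iter(): a TXT deletions file is read with pandas, and each ID becomes a row with an empty record payload. The two-line file below is a hypothetical stand-in; real files come from Transmogrifier.

```python
import io

import pandas as pd

# assume a deletions TXT file contains one timdex_record_id per line, no header
deletions_txt = io.StringIO("alma:123\nalma:456\n")

deleted_records_df = pd.read_csv(deletions_txt, header=None)
for row in deleted_records_df.itertuples():
    # row[0] is the positional index; row[1] is the single column (the record ID)
    print({"timdex_record_id": row[1], "record": None})
# {'timdex_record_id': 'alma:123', 'record': None}
# {'timdex_record_id': 'alma:456', 'record': None}
```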
@@ -192,6 +239,9 @@ def get_joined_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]
             SELECT
                 COALESCE(a.timdex_record_id, b.timdex_record_id) timdex_record_id,
                 COALESCE(a.source, b.source) source,
+                COALESCE(a.run_date, b.run_date) run_date,
+                COALESCE(a.run_type, b.run_type) run_type,
+                COALESCE(a.action, b.action) "action",
                 a.record as record_a,
                 b.record as record_b
             FROM a
@@ -210,15 +260,79 @@ def get_joined_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]
                 break


+def get_deduped_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]:
+    """Yield pyarrow.RecordBatch objects of deduped rows from the joined dataset.
+
+    ABDiff should be able to handle many input files, where a single timdex_record_id may
+    be duplicated across multiple files ("full" vs "daily" runs, incrementing date runs,
+    etc.)
+
+    This function writes the final dataset by deduping records from the temporary collated
+    dataset, given the following logic:
+        - use the MOST RECENT record based on 'run_date'
+        - if the MOST RECENT record is action='delete', then omit record entirely
+
+    The same mechanism is used by get_joined_batches_iter() to perform a DuckDB query then
+    stream write batches to a parquet dataset.
+    """
+    with duckdb.connect(":memory:") as con:
+
+        results = con.execute(
+            """
+            WITH collated as (
+                select * from read_parquet($collated_parquet_glob, hive_partitioning=true)
+            ),
+            latest_records AS (
+                SELECT
+                    *,
+                    ROW_NUMBER() OVER (
+                        PARTITION BY timdex_record_id
+                        ORDER BY run_date DESC
+                    ) AS rn
+                FROM collated
+            ),
+            deduped_records AS (
+                SELECT *
+                FROM latest_records
+                WHERE rn = 1 AND action != 'delete'
+            )
+            SELECT
+                timdex_record_id,
+                source,
+                run_date,
+                run_type,
+                action,
+                record_a,
+                record_b
+            FROM deduped_records;
+            """,
+            {
+                "collated_parquet_glob": f"{dataset_directory}/**/*.parquet",
+            },
+        ).fetch_record_batch(READ_BATCH_SIZE)
+
+        while True:
+            try:
+                yield results.read_next_batch()
+            except StopIteration:
+                break  # pragma: nocover
+
+
 def validate_output(dataset_path: str) -> None:
     """Validate the output of collate_ab_transforms.

     This function checks whether the collated dataset is empty
     and whether any or both 'record_a' or 'record_b' columns are
     totally empty.
     """
+
+    def fetch_single_value(query: str) -> int:
+        result = con.execute(query).fetchone()
+        if result is None:
+            raise RuntimeError(f"Query returned no results: {query}")  # pragma: nocover
+        return int(result[0])
+
     with duckdb.connect(":memory:") as con:
-        # create view of collated table
         con.execute(
             f"""
             CREATE VIEW collated AS (
@@ -228,41 +342,48 @@ def validate_output(dataset_path: str) -> None:
         )

         # check if the table is empty
-        record_count = con.execute("SELECT COUNT(*) FROM collated").fetchone()[0]  # type: ignore[index]
+        record_count = fetch_single_value("SELECT COUNT(*) FROM collated")
         if record_count == 0:
-            raise OutputValidationError(  # noqa: TRY003
+            raise OutputValidationError(
                 "The collated dataset does not contain any records."
             )

         # check if any of the 'record_*' columns are empty
-        record_a_null_count = con.execute(
+        record_a_null_count = fetch_single_value(
             "SELECT COUNT(*) FROM collated WHERE record_a ISNULL"
-        ).fetchone()[
-            0
-        ]  # type: ignore[index]
-
-        record_b_null_count = con.execute(
+        )
+        record_b_null_count = fetch_single_value(
             "SELECT COUNT(*) FROM collated WHERE record_b ISNULL"
-        ).fetchone()[
-            0
-        ]  # type: ignore[index]
+        )

         if record_count in {record_a_null_count, record_b_null_count}:
-            raise OutputValidationError(  # noqa: TRY003
+            raise OutputValidationError(
                 "At least one or both record column(s) ['record_a', 'record_b'] "
                 "in the collated dataset are empty."
             )

+        # check that timdex_record_id column is unique
+        non_unique_count = fetch_single_value(
+            """
+            SELECT COUNT(*)
+            FROM (
+                SELECT timdex_record_id
+                FROM collated
+                GROUP BY timdex_record_id
+                HAVING COUNT(*) > 1
+            ) as duplicates;
+            """
+        )
+        if non_unique_count > 0:
+            raise OutputValidationError(
+                "The collated dataset contains duplicate 'timdex_record_id' records."
+            )
+

 def get_transform_version(transformed_filepath: str) -> str:
     """Get A/B transform version, either 'a' or 'b'."""
-    match_result = re.match(
-        r".*transformed\/(.*)\/.*.json",
-        transformed_filepath,
-    )
+    match_result = re.match(r".*transformed\/(.*)\/.*", transformed_filepath)
     if not match_result:
-        raise ValueError(  # noqa: TRY003
-            f"Transformed filepath is invalid: {transformed_filepath}."
-        )
+        raise ValueError(f"Transformed filepath is invalid: {transformed_filepath}.")

     return match_result.groups()[0]
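
To make the dedupe rule concrete, here is a minimal, self-contained sketch of the same ROW_NUMBER pattern against a toy in-memory DuckDB table (table name and values are illustrative, not ABDiff data):

```python
import duckdb

con = duckdb.connect(":memory:")
con.execute(
    """
    CREATE TABLE collated AS
    SELECT * FROM (VALUES
        ('alma:1', DATE '2024-01-01', 'index'),
        ('alma:1', DATE '2024-02-01', 'delete'),
        ('alma:2', DATE '2024-01-01', 'index'),
        ('alma:2', DATE '2024-02-01', 'index')
    ) AS t(timdex_record_id, run_date, action)
    """
)
rows = con.execute(
    """
    WITH latest_records AS (
        SELECT
            *,
            ROW_NUMBER() OVER (
                PARTITION BY timdex_record_id
                ORDER BY run_date DESC
            ) AS rn
        FROM collated
    )
    SELECT timdex_record_id, run_date, action
    FROM latest_records
    WHERE rn = 1 AND action != 'delete'
    """
).fetchall()

# only 'alma:2' survives: its most recent row is an 'index'; the most recent
# 'alma:1' row is a 'delete', so that record is omitted entirely
print(rows)
```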

abdiff/core/run_ab_transforms.py

Lines changed: 34 additions & 19 deletions
@@ -110,7 +110,7 @@ def run_ab_transforms(
             "to complete successfully."
         )
     ab_transformed_file_lists = get_transformed_files(run_directory)
-    validate_output(ab_transformed_file_lists, len(input_files))
+    validate_output(ab_transformed_file_lists, input_files)

     # write and return results
     run_data = {
@@ -278,11 +278,11 @@ def get_transformed_files(run_directory: str) -> tuple[list[str], ...]:

     Returns:
         tuple[list[str]]: Tuple containing lists of paths to transformed
-            JSON files for each image, relative to 'run_directory'.
+            JSON and TXT (deletions) files for each image, relative to 'run_directory'.
     """
     ordered_files = []
     for version in ["a", "b"]:
-        absolute_filepaths = glob.glob(f"{run_directory}/transformed/{version}/*.json")
+        absolute_filepaths = glob.glob(f"{run_directory}/transformed/{version}/*")
         relative_filepaths = [
             os.path.relpath(file, run_directory) for file in absolute_filepaths
         ]
@@ -291,24 +291,39 @@ def get_transformed_files(run_directory: str) -> tuple[list[str], ...]:


 def validate_output(
-    ab_transformed_file_lists: tuple[list[str], ...], input_files_count: int
+    ab_transformed_file_lists: tuple[list[str], ...], input_files: list[str]
 ) -> None:
     """Validate the output of run_ab_transforms.

-    This function checks that the number of files in each of the A/B
-    transformed file directories matches the number of input files
-    provided to run_ab_transforms (i.e., the expected number of
-    files that are transformed).
+    Transmogrifier produces JSON files for records that need indexing, and TXT files for
+    records that need deletion. Every run of Transmogrifier should produce one OR both of
+    these. Some TIMDEX sources provide one file to Transmogrifier that contains both
+    records to index and delete, and others provide separate files for each.
+
+    The net effect for validation is that, given an input file, we should expect to see
+    1+ files in the A and B output for that input file, ignoring if it's records to index
+    or delete.
     """
-    if any(
-        len(transformed_files) != input_files_count
-        for transformed_files in ab_transformed_file_lists
-    ):
-        raise OutputValidationError(  # noqa: TRY003
-            "At least one or more transformed JSON file(s) are missing. "
-            f"Expecting {input_files_count} transformed JSON file(s) per A/B version. "
-            "Check the transformed file directories."
-        )
+    for input_file in input_files:
+        file_parts = parse_timdex_filename(input_file)
+        logger.debug(f"Validating output for input file root: {file_parts}")
+
+        file_found = False
+        for version_files in ab_transformed_file_lists:
+            for version_file in version_files:
+                if (
+                    file_parts["source"] in version_file  # type: ignore[operator]
+                    and file_parts["run-date"] in version_file  # type: ignore[operator]
+                    and file_parts["run-type"] in version_file  # type: ignore[operator]
+                    and (not file_parts["index"] or file_parts["index"] in version_file)
+                ):
+                    file_found = True
+                    break
+
+        if not file_found:
+            raise OutputValidationError(  # noqa: TRY003
+                f"Transmogrifier output was not found for input file '{input_file}'"
+            )


 def get_transformed_filename(filename_details: dict) -> str:
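
The new validate_output() matches each input file to transformed outputs by checking filename parts as substrings. A quick hypothetical illustration (the file_parts dict and output paths below stand in for parse_timdex_filename() output and real run files, which are not shown in this diff):

```python
# hypothetical parsed parts for one input file
file_parts = {
    "source": "alma",
    "run-date": "2024-01-01",
    "run-type": "daily",
    "index": None,
}

# hypothetical A-side transformed outputs for the run
version_files = [
    "transformed/a/alma-2024-01-01-daily-transformed-records-to-index.json",
    "transformed/a/alma-2024-01-01-daily-transformed-records-to-delete.json",
]

file_found = any(
    file_parts["source"] in version_file
    and file_parts["run-date"] in version_file
    and file_parts["run-type"] in version_file
    and (not file_parts["index"] or file_parts["index"] in version_file)
    for version_file in version_files
)
print(file_found)  # True: at least one output file matches this input file
```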
@@ -318,13 +333,13 @@ def get_transformed_filename(filename_details: dict) -> str:
         index=f"_{sequence}" if (sequence := filename_details["index"]) else "",
     )
     output_filename = (
-        "{source}-{run_date}-{run_type}-{stage}-records-to-index{index}.{file_type}"
+        "{source}-{run_date}-{run_type}-{stage}-records-to-{action}{index}.json"
     )
     return output_filename.format(
         source=filename_details["source"],
         run_date=filename_details["run-date"],
         run_type=filename_details["run-type"],
         stage=filename_details["stage"],
         index=filename_details["index"],
-        file_type=filename_details["file_type"],
+        action=filename_details["action"],
     )
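
For reference, formatting the updated template with hypothetical filename_details values (the "stage" value here is a guess for illustration) yields a deletions filename like this:

```python
# hypothetical filename_details; keys mirror those used by get_transformed_filename()
filename_details = {
    "source": "dspace",
    "run-date": "2024-01-01",
    "run-type": "daily",
    "stage": "transformed",
    "index": "",
    "action": "delete",
}

output_filename = "{source}-{run_date}-{run_type}-{stage}-records-to-{action}{index}.json"
print(
    output_filename.format(
        source=filename_details["source"],
        run_date=filename_details["run-date"],
        run_type=filename_details["run-type"],
        stage=filename_details["stage"],
        index=filename_details["index"],
        action=filename_details["action"],
    )
)
# dspace-2024-01-01-daily-transformed-records-to-delete.json
```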
