Skip to content

Commit e436d53

Browse files
authored
Merge pull request #48 from MITLibraries/TIMX-383-pipeline-tweaks-large-runs
TIMX 383 - pipeline tweaks for large runs
2 parents 16ecf73 + a26c0a8 commit e436d53

34 files changed

+702
-23485
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ WEBAPP_PORT=# port for flask webapp
137137
TRANSMOGRIFIER_MAX_WORKERS=# max number of Transmogrifier containers to run in parallel; default is 6
138138
TRANSMOGRIFIER_TIMEOUT=# timeout for a single Transmogrifier container; default is 5 hours
139139
TIMDEX_BUCKET=# when using CLI command 'timdex-sources-csv', this is required to know what TIMDEX bucket to use
140+
PRESERVE_ARTIFACTS=# if 'true', intermediate artifacts like transformed files, collated records, etc., will not be automatically removed
141+
ALLOW_FAILED_TRANSMOGRIFIER_CONTAINERS=# if 'true', the run will continue even if some Transmogrifier containers failed to complete successfully
140142
```
141143

142144
## CLI commands

abdiff/cli.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import shutil
34
from datetime import timedelta
45
from itertools import chain
56
from time import perf_counter
@@ -14,6 +15,7 @@
1415
calc_ab_diffs,
1516
calc_ab_metrics,
1617
collate_ab_transforms,
18+
create_final_records,
1719
download_input_files,
1820
init_run,
1921
run_ab_transforms,
@@ -25,6 +27,8 @@
2527

2628
logger = logging.getLogger(__name__)
2729

30+
CONFIG = Config()
31+
2832

2933
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
3034
@click.option(
@@ -181,19 +185,31 @@ def run_diff(
181185
input_files=input_files_list,
182186
use_local_s3=download_files,
183187
)
188+
184189
collated_dataset_path = collate_ab_transforms(
185190
run_directory=run_directory,
186191
ab_transformed_file_lists=ab_transformed_file_lists,
187192
)
193+
188194
diffs_dataset_path = calc_ab_diffs(
189195
run_directory=run_directory,
190196
collated_dataset_path=collated_dataset_path,
191197
)
192-
calc_ab_metrics(
198+
199+
if not CONFIG.preserve_artifacts:
200+
shutil.rmtree(collated_dataset_path)
201+
202+
metrics_dataset_path = calc_ab_metrics(
193203
run_directory=run_directory,
194204
diffs_dataset_path=diffs_dataset_path,
195205
)
196206

207+
create_final_records(run_directory, diffs_dataset_path, metrics_dataset_path)
208+
209+
if not CONFIG.preserve_artifacts:
210+
shutil.rmtree(diffs_dataset_path)
211+
shutil.rmtree(metrics_dataset_path)
212+
197213

198214
@main.command()
199215
@click.option(

abdiff/config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class Config:
2121
"TRANSMOGRIFIER_MAX_WORKERS",
2222
"TRANSMOGRIFIER_TIMEOUT",
2323
"TIMDEX_BUCKET",
24+
"PRESERVE_ARTIFACTS",
25+
"ALLOW_FAILED_TRANSMOGRIFIER_CONTAINERS",
2426
)
2527

2628
def __getattr__(self, name: str) -> Any: # noqa: ANN401
@@ -81,6 +83,19 @@ def active_timdex_sources(self) -> list[str]:
8183
"researchdatabases",
8284
]
8385

86+
@property
87+
def preserve_artifacts(self) -> bool:
88+
return bool(
89+
self.PRESERVE_ARTIFACTS and self.PRESERVE_ARTIFACTS.strip().lower() == "true"
90+
)
91+
92+
@property
93+
def allow_failed_transmogrifier_containers(self) -> bool:
94+
return bool(
95+
self.ALLOW_FAILED_TRANSMOGRIFIER_CONTAINERS
96+
and self.ALLOW_FAILED_TRANSMOGRIFIER_CONTAINERS.strip().lower() == "true"
97+
)
98+
8499

85100
def configure_logger(logger: logging.Logger, *, verbose: bool) -> str:
86101
if verbose:

abdiff/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from abdiff.core.calc_ab_diffs import calc_ab_diffs
88
from abdiff.core.calc_ab_metrics import calc_ab_metrics
99
from abdiff.core.collate_ab_transforms import collate_ab_transforms
10+
from abdiff.core.create_final_records import create_final_records
1011
from abdiff.core.init_job import init_job
1112
from abdiff.core.init_run import init_run
1213
from abdiff.core.run_ab_transforms import run_ab_transforms
@@ -21,4 +22,5 @@
2122
"collate_ab_transforms",
2223
"calc_ab_diffs",
2324
"calc_ab_metrics",
25+
"create_final_records",
2426
]

abdiff/core/calc_ab_diffs.py

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import concurrent.futures
12
import json
23
import logging
34
import time
@@ -12,9 +13,9 @@
1213

1314
logger = logging.getLogger(__name__)
1415

15-
READ_BATCH_SIZE = 1_000
16-
WRITE_MAX_ROW_GROUP_SIZE = 1_000
16+
READ_BATCH_SIZE = 10_000
1717
WRITE_MAX_ROWS_PER_FILE = 100_000
18+
MAX_PARALLEL_WORKERS = 6
1819

1920
DIFFS_DATASET_OUTPUT_SCHEMA = pa.schema(
2021
(
@@ -52,29 +53,55 @@ def calc_ab_diffs(run_directory: str, collated_dataset_path: str) -> str:
5253
return str(diffs_dataset)
5354

5455

56+
def process_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
57+
"""Parallel worker for calculating record diffs for a batch.
58+
59+
The pyarrow RecordBatch is converted into a pandas dataframe, a diff is calculated via
60+
DeepDiff for each record in the batch, and this is converted back to a pyarrow
61+
RecordBatch for returning.
62+
"""
63+
df = batch.to_pandas() # noqa: PD901
64+
diff_results = df.apply(
65+
lambda row: calc_record_diff(row["record_a"], row["record_b"]), axis=1
66+
)
67+
df["ab_diff"] = diff_results.apply(lambda x: x[0])
68+
df["modified_timdex_fields"] = diff_results.apply(
69+
lambda x: list(x[1]) if x[1] else []
70+
)
71+
df["has_diff"] = diff_results.apply(lambda x: x[2])
72+
return pa.RecordBatch.from_pandas(df) # type: ignore[attr-defined]
73+
74+
5575
def get_diffed_batches_iter(
5676
collated_dataset: ds.Dataset,
5777
batch_size: int = READ_BATCH_SIZE,
78+
max_parallel_processes: int = MAX_PARALLEL_WORKERS,
5879
) -> Generator[pa.RecordBatch, None, None]:
59-
"""Yield pyarrow record batches with diff calculated for records in batch."""
80+
"""Yield pyarrow record batches with diff calculated for each record.
81+
82+
This work is performed in parallel, leveraging CPU cores to calculate the diffs and
83+
yield batches for writing to the "diffs" dataset.
84+
"""
6085
batches_iter = collated_dataset.to_batches(batch_size=batch_size)
61-
for i, batch in enumerate(batches_iter):
62-
logger.info(f"Calculating AB diff for batch: {i}")
6386

64-
# convert batch to pandas dataframe and calc values for new columns
65-
df = batch.to_pandas() # noqa: PD901
87+
with concurrent.futures.ProcessPoolExecutor(
88+
max_workers=max_parallel_processes + 1
89+
) as executor:
90+
pending_futures = []
91+
for batch_count, batch in enumerate(batches_iter):
92+
future = executor.submit(process_batch, batch)
93+
pending_futures.append((batch_count, future))
6694

67-
# calculate all diffs and unpack into separate columns
68-
diff_results = df.apply(
69-
lambda row: calc_record_diff(row["record_a"], row["record_b"]), axis=1
70-
)
71-
df["ab_diff"] = diff_results.apply(lambda x: x[0])
72-
df["modified_timdex_fields"] = diff_results.apply(
73-
lambda x: list(x[1]) if x[1] else []
74-
)
75-
df["has_diff"] = diff_results.apply(lambda x: x[2])
95+
if len(pending_futures) >= max_parallel_processes:
96+
idx, completed_future = pending_futures.pop(0)
97+
result = completed_future.result()
98+
logger.info(f"Yielding diffed batch: {idx}")
99+
yield result
76100

77-
yield pa.RecordBatch.from_pandas(df) # type: ignore[attr-defined]
101+
for idx, future in pending_futures:
102+
result = future.result()
103+
logger.info(f"Yielding diffed batch: {idx}")
104+
yield result
78105

79106

80107
def calc_record_diff(
@@ -83,32 +110,53 @@ def calc_record_diff(
83110
*,
84111
ignore_order: bool = True,
85112
report_repetition: bool = True,
86-
) -> tuple[str | None, list[str] | None, bool]:
113+
) -> tuple[str, set[str], bool]:
87114
"""Calculate diff from two JSON byte strings.
88115
89116
The DeepDiff library has the property 'affected_root_keys' on the produced diff object
90117
that is very useful for our purposes. At this time, we simply want to know if
91118
anything about a particular root level TIMDEX field (e.g. 'dates' or 'title') has
92-
changed which this method provides explicitly. We also serialize the full diff to
93-
JSON via the to_json() method for storage and possible further analysis.
119+
changed which this method provides explicitly. In the unlikely case that the records
120+
share ZERO keys, a special case is handled where the modified root paths are returned
121+
as only ['root'], in which case we get a combined set of keys from both records, which is
122+
effectively the modified root fields.
123+
124+
We also serialize the full diff to JSON via the to_json() method for storage and
125+
possible further analysis.
94126
95-
This method returns a tuple:
127+
Returns tuple(ab_diff, modified_timdex_fields, has_diff):
96128
- ab_diff: [str] - full diff as JSON
97129
- modified_timdex_fields: list[str] - list of modified root keys (TIMDEX fields)
98130
- has_diff: bool - True/False if any diff present
99131
"""
100-
if record_a is None or record_b is None:
101-
return None, None, False
132+
# Replace None with empty dict
133+
record_a = record_a or {}
134+
record_b = record_b or {}
135+
136+
# Parse JSON strings or bytes into dictionaries
137+
if isinstance(record_a, (str | bytes)):
138+
record_a = json.loads(record_a)
139+
if isinstance(record_b, (str | bytes)):
140+
record_b = json.loads(record_b)
102141

103142
diff = DeepDiff(
104-
json.loads(record_a) if isinstance(record_a, str | bytes) else record_a,
105-
json.loads(record_b) if isinstance(record_b, str | bytes) else record_b,
143+
record_a,
144+
record_b,
106145
ignore_order=ignore_order,
107146
report_repetition=report_repetition,
108147
)
109148

110149
ab_diff = diff.to_json()
111-
modified_timdex_fields = diff.affected_root_keys
150+
151+
# get modified root fields, handling edge cases
152+
if diff.affected_paths != ["root"]:
153+
modified_timdex_fields = diff.affected_root_keys
154+
else:
155+
modified_timdex_fields = set()
156+
for record in [record_a, record_b]:
157+
if isinstance(record, dict):
158+
modified_timdex_fields.update(record.keys())
159+
112160
has_diff = bool(modified_timdex_fields)
113161

114162
return ab_diff, modified_timdex_fields, has_diff

abdiff/core/calc_ab_metrics.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
def calc_ab_metrics(
2121
run_directory: str,
2222
diffs_dataset_path: str,
23-
) -> dict:
23+
) -> str:
2424

25-
os.makedirs(Path(run_directory) / "metrics", exist_ok=True)
25+
metrics_dataset = Path(run_directory) / "metrics"
26+
os.makedirs(metrics_dataset, exist_ok=True)
2627

2728
# build field diffs dataframe
2829
field_matrix_dataset_filepath = create_record_diff_matrix_dataset(
@@ -37,7 +38,7 @@ def calc_ab_metrics(
3738
run_directory=run_directory, new_data={"metrics": metrics_data}
3839
)
3940

40-
return metrics_data
41+
return str(metrics_dataset)
4142

4243

4344
def create_record_diff_matrix_dataset(

0 commit comments

Comments
 (0)