Commit 271b37b

[ops] Parallelize cross_region log downloads with ThreadPoolExecutor
Download parquet log files on up to 16 concurrent threads instead of one at a time. Large cross-region analyses with many log files were bottlenecked on single-threaded GCS downloads.
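The shape of the change, in miniature: submit each download to a bounded thread pool, then collect results with as_completed, keying each future back to its input. A minimal sketch of that pattern (fetch_one and items are illustrative stand-ins, not code from this commit):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def fetch_one(item: str) -> str:
        # Stand-in for one blocking I/O call, e.g. a single GCS object read.
        return item.upper()

    items = ["a", "b", "c"]
    results = {}

    with ThreadPoolExecutor(max_workers=16) as pool:
        # as_completed yields futures in completion order, not submission
        # order, so key each future by its input to re-associate results.
        futures = {pool.submit(fetch_one, item): item for item in items}
        for fut in as_completed(futures):
            results[futures[fut]] = fut.result()  # re-raises worker exceptions

    assert results == {"a": "A", "b": "B", "c": "C"}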
1 parent e233ba1

1 file changed

scripts/ops/cross_region.py

Lines changed: 40 additions & 25 deletions
@@ -24,6 +24,7 @@
 import sqlite3
 import time
 from collections import Counter
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from pathlib import Path
 
@@ -246,36 +247,50 @@ def choose_log_objects(
     return chosen
 
 
+def _download_one(fs, entry: dict, target_dir: Path, index: int, total: int) -> Path:
+    name = entry["name"]
+    remote_size = entry.get("size")
+    local_path = target_dir / Path(name).name
+    cached_ok = local_path.exists() and remote_size is not None and local_path.stat().st_size == remote_size
+    if cached_ok:
+        logging.info(f" [{index}/{total}] {Path(name).name} already cached")
+        return local_path
+    if local_path.exists():
+        logging.info(
+            f" [{index}/{total}] {Path(name).name} cached but size mismatch "
+            f"(local={local_path.stat().st_size}, remote={remote_size}); re-downloading"
+        )
+    else:
+        logging.info(f" [{index}/{total}] Downloading {Path(name).name}")
+    tmp_path = local_path.with_suffix(local_path.suffix + ".part")
+    with fs.open(name, "rb") as src, tmp_path.open("wb") as dst:
+        while True:
+            chunk = src.read(8 * 1024 * 1024)
+            if not chunk:
+                break
+            dst.write(chunk)
+    tmp_path.replace(local_path)
+    return local_path
+
+
 def download_log_objects(remote_logs_dir: str, entries: list[dict], target_dir: Path) -> list[Path]:
     fs, _ = fsspec.core.url_to_fs(remote_logs_dir)
     target_dir.mkdir(parents=True, exist_ok=True)
     logging.info(f"Downloading {len(entries)} parquet files to {target_dir}")
 
-    local_paths: list[Path] = []
-    for i, entry in enumerate(entries, 1):
-        name = entry["name"]
-        remote_size = entry.get("size")
-        local_path = target_dir / Path(name).name
-        cached_ok = local_path.exists() and remote_size is not None and local_path.stat().st_size == remote_size
-        if cached_ok:
-            logging.info(f" [{i}/{len(entries)}] {Path(name).name} already cached")
-        else:
-            if local_path.exists():
-                logging.info(
-                    f" [{i}/{len(entries)}] {Path(name).name} cached but size mismatch "
-                    f"(local={local_path.stat().st_size}, remote={remote_size}); re-downloading"
-                )
-            else:
-                logging.info(f" [{i}/{len(entries)}] Downloading {Path(name).name}")
-            tmp_path = local_path.with_suffix(local_path.suffix + ".part")
-            with fs.open(name, "rb") as src, tmp_path.open("wb") as dst:
-                while True:
-                    chunk = src.read(8 * 1024 * 1024)
-                    if not chunk:
-                        break
-                    dst.write(chunk)
-            tmp_path.replace(local_path)
-        local_paths.append(local_path)
+    index_map = {id(entry): i for i, entry in enumerate(entries, 1)}
+    total = len(entries)
+    results: dict[int, Path] = {}
+
+    with ThreadPoolExecutor(max_workers=16) as pool:
+        futures = {
+            pool.submit(_download_one, fs, entry, target_dir, index_map[id(entry)], total): entry for entry in entries
+        }
+        for fut in as_completed(futures):
+            entry = futures[fut]
+            results[index_map[id(entry)]] = fut.result()
+
+    local_paths = [results[i] for i in range(1, total + 1)]
     logging.info(f"Finished downloading {len(local_paths)} parquet files")
     return local_paths
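A note on the ordering choice: because as_completed yields futures as they finish, the new code records each result under the entry's original position (index_map, keyed on object identity) and rebuilds local_paths in submission order at the end, so callers see the same ordering as the old sequential loop; fut.result() still re-raises the first worker exception, failing the download as before. A hedged usage sketch, assuming entries shaped like fsspec listing dicts with "name" and "size" keys (the bucket path below is a placeholder, not from this repo):

    from pathlib import Path

    import fsspec

    remote = "gs://example-bucket/cross-region/logs"  # placeholder URL
    fs, root = fsspec.core.url_to_fs(remote)
    # fs.ls(..., detail=True) returns dicts with "name" and "size" keys.
    entries = [e for e in fs.ls(root, detail=True) if e["name"].endswith(".parquet")]
    paths = download_log_objects(remote, entries, Path("/tmp/cross_region_logs"))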
