|
24 | 24 | import sqlite3 |
25 | 25 | import time |
26 | 26 | from collections import Counter |
| 27 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
27 | 28 | from dataclasses import dataclass |
28 | 29 | from pathlib import Path |
29 | 30 |
|
@@ -246,36 +247,50 @@ def choose_log_objects( |
246 | 247 | return chosen |
247 | 248 |
|
248 | 249 |
|
def _download_one(fs, entry: dict, target_dir: Path, index: int, total: int) -> Path:
    """Fetch one remote object into *target_dir*, reusing a valid cached copy.

    Args:
        fs: fsspec-like filesystem exposing ``open(name, "rb")``.
        entry: listing entry; requires ``"name"`` (remote path), and uses
            ``"size"`` (remote byte size) for cache validation when present.
        target_dir: existing local directory that receives the file.
        index: 1-based position of this entry — used only in log messages.
        total: total number of entries — used only in log messages.

    Returns:
        Path to the local copy of the object.

    Raises:
        Whatever ``fs.open`` / the write path raises; any partially written
        ``.part`` file is removed before the exception propagates.
    """
    name = entry["name"]
    basename = Path(name).name  # hoisted: used in every log message and the target path
    remote_size = entry.get("size")
    local_path = target_dir / basename
    # A cached file only counts when the remote size is known and matches exactly;
    # an unknown remote size forces a re-download.
    cached_ok = local_path.exists() and remote_size is not None and local_path.stat().st_size == remote_size
    if cached_ok:
        logging.info(f" [{index}/{total}] {basename} already cached")
        return local_path
    if local_path.exists():
        logging.info(
            f" [{index}/{total}] {basename} cached but size mismatch "
            f"(local={local_path.stat().st_size}, remote={remote_size}); re-downloading"
        )
    else:
        logging.info(f" [{index}/{total}] Downloading {basename}")
    # Download to a sibling ".part" file and atomically rename on success so a
    # crash mid-download never leaves a truncated file under the final name.
    tmp_path = local_path.with_suffix(local_path.suffix + ".part")
    try:
        with fs.open(name, "rb") as src, tmp_path.open("wb") as dst:
            while chunk := src.read(8 * 1024 * 1024):  # 8 MiB chunks keep memory bounded
                dst.write(chunk)
        tmp_path.replace(local_path)
    except BaseException:
        # Fix: don't leave a stale .part file behind on failure/cancellation.
        tmp_path.unlink(missing_ok=True)
        raise
    return local_path
| 275 | + |
def download_log_objects(remote_logs_dir: str, entries: list[dict], target_dir: Path) -> list[Path]:
    """Download the listed remote objects in parallel, preserving input order.

    Args:
        remote_logs_dir: fsspec URL of the remote logs directory; only used to
            resolve the filesystem implementation.
        entries: listing entries as consumed by ``_download_one`` (``"name"``
            required, ``"size"`` used for cache validation).
        target_dir: local destination directory; created if missing.

    Returns:
        Local paths in the same order as *entries*.

    Raises:
        Re-raises the first download failure observed via ``Future.result()``.
    """
    fs, _ = fsspec.core.url_to_fs(remote_logs_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"Downloading {len(entries)} parquet files to {target_dir}")

    total = len(entries)
    results: dict[int, Path] = {}

    with ThreadPoolExecutor(max_workers=16) as pool:
        # Fix: map each future directly to its 1-based position. The previous
        # id(entry)-keyed index map collapsed duplicate dict objects in
        # `entries` onto a single index, which dropped results and made the
        # ordered reassembly below raise KeyError.
        future_to_index = {
            pool.submit(_download_one, fs, entry, target_dir, i, total): i
            for i, entry in enumerate(entries, 1)
        }
        for fut in as_completed(future_to_index):
            results[future_to_index[fut]] = fut.result()

    # Reassemble in the original entry order regardless of completion order.
    local_paths = [results[i] for i in range(1, total + 1)]
    logging.info(f"Finished downloading {len(local_paths)} parquet files")
    return local_paths
281 | 296 |
|
|
0 commit comments