-
Notifications
You must be signed in to change notification settings - Fork 1
Parallelizes Data.content_hash() for large datasets
#86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,9 +7,32 @@ | |
| import hashlib | ||
| import os | ||
| import posixpath | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from functools import partial | ||
|
|
||
| from absl import logging | ||
|
|
||
# Directories with more files than this threshold are hashed in parallel
# using a thread pool. Below this, sequential hashing avoids pool overhead.
_PARALLEL_HASH_THRESHOLD = 16
# Number of (relpath, fpath) pairs hashed per thread-pool task. Batching
# amortizes Future-creation/scheduling overhead across many small files
# instead of submitting one task per file.
_HASH_BATCH_SIZE = 512
|
|
||
|
|
||
| def _hash_single_file(fpath: str, relpath: str) -> bytes: | ||
| """SHA-256 of relpath + \\0 + file contents. Returns raw 32-byte digest.""" | ||
| h = hashlib.sha256() | ||
| h.update(relpath.encode("utf-8")) | ||
| h.update(b"\0") | ||
| with open(fpath, "rb") as f: | ||
| for chunk in iter(partial(f.read, 65536), b""): | ||
| h.update(chunk) | ||
| return h.digest() | ||
|
|
||
|
|
||
| def _hash_file_batch(batch: list[tuple[str, str]]) -> list[bytes]: | ||
| """Hash a batch of (relpath, fpath) pairs. Returns list of 32-byte digests.""" | ||
| return [_hash_single_file(fpath, relpath) for relpath, fpath in batch] | ||
|
|
||
|
|
||
| class Data: | ||
| """A reference to data that should be available on the remote pod. | ||
|
|
@@ -78,6 +101,10 @@ def is_dir(self) -> bool: | |
| def content_hash(self) -> str: | ||
| """SHA-256 hash of all file contents, sorted by relative path. | ||
|
|
||
| Uses two-level hashing for parallelism: each file is hashed | ||
| independently (SHA-256 of relpath + contents), then per-file | ||
| digests are combined in sorted order into a final hash. | ||
|
|
||
| Includes a type prefix ("dir:" or "file:") to prevent collisions | ||
| between a single file and a directory containing only that file. | ||
|
|
||
|
|
@@ -88,34 +115,53 @@ def content_hash(self) -> str: | |
| """ | ||
| if self.is_gcs: | ||
| raise ValueError("Cannot compute content hash for GCS URI") | ||
| if os.path.isdir(self._resolved_path): | ||
| return self._content_hash_dir() | ||
| return self._content_hash_file() | ||
|
|
||
| def _content_hash_file(self) -> str: | ||
| h = hashlib.sha256() | ||
| if os.path.isdir(self._resolved_path): | ||
| h.update(b"dir:") | ||
| for root, dirs, files in os.walk(self._resolved_path, followlinks=False): | ||
| dirs.sort() | ||
| for fname in sorted(files): | ||
| fpath = os.path.join(root, fname) | ||
| relpath = os.path.relpath(fpath, self._resolved_path) | ||
| h.update(relpath.encode("utf-8")) | ||
| h.update(b"\0") | ||
| with open(fpath, "rb") as f: | ||
| while True: | ||
| chunk = f.read(65536) # 64 KB chunks | ||
| if not chunk: | ||
| break | ||
| h.update(chunk) | ||
| h.update(b"\0") | ||
| h.update(b"file:") | ||
| h.update( | ||
| _hash_single_file( | ||
| self._resolved_path, | ||
| os.path.basename(self._resolved_path), | ||
| ) | ||
| ) | ||
| return h.hexdigest() | ||
|
|
||
| def _content_hash_dir(self) -> str: | ||
| # Enumerate all files. Walk in filesystem order (better disk | ||
| # locality) and sort once at the end for determinism. | ||
| file_list = [] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Building a full file_list array covering every file in memory before chunking triggers the same RAM-saturation pattern found in other parts of the client/runner. For massive datasets (e.g., 20M+ files), iterating over the structure can cause an Out-Of-Memory (OOM) crash before the pool map even launches. |
||
| for root, _dirs, files in os.walk(self._resolved_path, followlinks=False): | ||
| for fname in files: | ||
| fpath = os.path.join(root, fname) | ||
| relpath = os.path.relpath(fpath, self._resolved_path) | ||
| file_list.append((relpath, fpath)) | ||
| file_list.sort() | ||
|
|
||
| # Hash each file independently. Use a thread pool for large | ||
| # directories to parallelize I/O-bound reads. Work is batched | ||
| # to avoid creating one Future per file. | ||
| if len(file_list) <= _PARALLEL_HASH_THRESHOLD: | ||
| digests = _hash_file_batch(file_list) | ||
| else: | ||
| h.update(b"file:") | ||
| h.update(os.path.basename(self._resolved_path).encode("utf-8")) | ||
| h.update(b"\0") | ||
| with open(self._resolved_path, "rb") as f: | ||
| while True: | ||
| chunk = f.read(65536) | ||
| if not chunk: | ||
| break | ||
| h.update(chunk) | ||
| batches = [ | ||
| file_list[i : i + _HASH_BATCH_SIZE] | ||
| for i in range(0, len(file_list), _HASH_BATCH_SIZE) | ||
| ] | ||
| max_workers = min(32, (os.cpu_count() or 4) + 4) | ||
| with ThreadPoolExecutor(max_workers=max_workers) as pool: | ||
| digests = [] | ||
| for batch_digests in pool.map(_hash_file_batch, batches): | ||
| digests.extend(batch_digests) | ||
JyotinderSingh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Combine per-file digests (each exactly 32 bytes) into final hash. | ||
| h = hashlib.sha256() | ||
| h.update(b"dir:") | ||
| for digest in digests: | ||
| h.update(digest) | ||
| return h.hexdigest() | ||
|
|
||
| def __repr__(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you leave a comment as to why it's chunked this way
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just aligned with the default chunk size for python stdlib utilities. 64-256kb is also the standard nvme page size.
shutil.copyfileobj uses a 64kb chunk, while hashlib.file_digest uses a 256kb chunk