diff --git a/keras_remote/data/data.py b/keras_remote/data/data.py
index e09502a..e121b50 100644
--- a/keras_remote/data/data.py
+++ b/keras_remote/data/data.py
@@ -7,9 +7,35 @@
 import hashlib
 import os
 import posixpath
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 
 from absl import logging
 
+# Directories with more files than this threshold are hashed in parallel
+# using a thread pool. Below this, sequential hashing avoids pool overhead.
+_PARALLEL_HASH_THRESHOLD = 16
+# Upper bound on the number of files handed to one pool task. Smaller
+# directories use smaller batches so that every worker still gets work.
+_HASH_BATCH_SIZE = 512
+
+
+def _hash_single_file(fpath: str, relpath: str) -> bytes:
+  """SHA-256 of relpath + \\0 + file contents. Returns raw 32-byte digest."""
+  h = hashlib.sha256()
+  h.update(relpath.encode("utf-8"))
+  h.update(b"\0")
+  # 256 KB: matches hashlib.file_digest's default buffer size.
+  with open(fpath, "rb") as f:
+    for chunk in iter(partial(f.read, 2**18), b""):
+      h.update(chunk)
+  return h.digest()
+
+
+def _hash_file_batch(batch: list[tuple[str, str]]) -> list[bytes]:
+  """Hash a batch of (relpath, fpath) pairs. Returns list of 32-byte digests."""
+  return [_hash_single_file(fpath, relpath) for relpath, fpath in batch]
+
 
 class Data:
   """A reference to data that should be available on the remote pod.
@@ -78,6 +104,10 @@ def is_dir(self) -> bool:
   def content_hash(self) -> str:
     """SHA-256 hash of all file contents, sorted by relative path.
 
+    Uses two-level hashing for parallelism: each file is hashed
+    independently (SHA-256 of relpath + contents), then per-file
+    digests are combined in sorted order into a final hash.
+
     Includes a type prefix ("dir:" or "file:") to prevent collisions
     between a single file and a directory containing only that file.
 
@@ -88,34 +118,58 @@ def content_hash(self) -> str:
     """
     if self.is_gcs:
       raise ValueError("Cannot compute content hash for GCS URI")
+    if os.path.isdir(self._resolved_path):
+      return self._content_hash_dir()
+    return self._content_hash_file()
 
+  def _content_hash_file(self) -> str:
+    """Return the hex digest for a single local file."""
     h = hashlib.sha256()
-    if os.path.isdir(self._resolved_path):
-      h.update(b"dir:")
-      for root, dirs, files in os.walk(self._resolved_path, followlinks=False):
-        dirs.sort()
-        for fname in sorted(files):
-          fpath = os.path.join(root, fname)
-          relpath = os.path.relpath(fpath, self._resolved_path)
-          h.update(relpath.encode("utf-8"))
-          h.update(b"\0")
-          with open(fpath, "rb") as f:
-            while True:
-              chunk = f.read(65536)  # 64 KB chunks
-              if not chunk:
-                break
-              h.update(chunk)
-          h.update(b"\0")
+    h.update(b"file:")
+    h.update(
+      _hash_single_file(
+        self._resolved_path,
+        os.path.basename(self._resolved_path),
+      )
+    )
+    return h.hexdigest()
+
+  def _content_hash_dir(self) -> str:
+    """Return the hex digest for a local directory tree."""
+    # Enumerate all files. Walk in filesystem order (better disk
+    # locality) and sort once at the end for determinism.
+    file_list = []
+    for root, _dirs, files in os.walk(self._resolved_path, followlinks=False):
+      for fname in files:
+        fpath = os.path.join(root, fname)
+        relpath = os.path.relpath(fpath, self._resolved_path)
+        file_list.append((relpath, fpath))
+    file_list.sort()
+
+    # Hash each file independently. Use a thread pool for large
+    # directories to parallelize I/O-bound reads. Work is batched
+    # to avoid creating one Future per file.
+    if len(file_list) <= _PARALLEL_HASH_THRESHOLD:
+      digests = _hash_file_batch(file_list)
     else:
-      h.update(b"file:")
-      h.update(os.path.basename(self._resolved_path).encode("utf-8"))
-      h.update(b"\0")
-      with open(self._resolved_path, "rb") as f:
-        while True:
-          chunk = f.read(65536)
-          if not chunk:
-            break
-          h.update(chunk)
+      max_workers = min(32, (os.cpu_count() or 4) + 4)
+      # Size batches so every worker gets work: a fixed 512-file batch
+      # would run any directory under 512 files on a single thread.
+      batch_size = min(_HASH_BATCH_SIZE, -(-len(file_list) // max_workers))
+      batches = [
+        file_list[i : i + batch_size]
+        for i in range(0, len(file_list), batch_size)
+      ]
+      with ThreadPoolExecutor(max_workers=max_workers) as pool:
+        digests = []
+        for batch_digests in pool.map(_hash_file_batch, batches):
+          digests.extend(batch_digests)
+
+    # Combine per-file digests (each exactly 32 bytes) into final hash.
+    h = hashlib.sha256()
+    h.update(b"dir:")
+    for digest in digests:
+      h.update(digest)
     return h.hexdigest()
 
   def __repr__(self):
diff --git a/keras_remote/data/data_test.py b/keras_remote/data/data_test.py
index ecc16ab..2fffc9d 100644
--- a/keras_remote/data/data_test.py
+++ b/keras_remote/data/data_test.py
@@ -7,6 +7,7 @@
 from absl.testing import absltest
 
 from keras_remote.data import Data, _make_data_ref, is_data_ref
+from keras_remote.data.data import _PARALLEL_HASH_THRESHOLD
 
 
 def _make_temp_path(test_case):
@@ -216,6 +217,34 @@ def test_path_included_in_hash(self):
       Data(str(d1)).content_hash(), Data(str(d2)).content_hash()
     )
 
+  def test_parallel_determinism_many_files(self):
+    """Directory with many files exercises the thread pool path and
+    must still produce deterministic hashes."""
+    tmp = _make_temp_path(self)
+    d = tmp / "large_dir"
+    d.mkdir()
+    num_files = _PARALLEL_HASH_THRESHOLD + 30
+    for i in range(num_files):
+      (d / f"file_{i:04d}.txt").write_text(f"content_{i}")
+
+    hashes = [Data(str(d)).content_hash() for _ in range(5)]
+    self.assertTrue(all(h == hashes[0] for h in hashes))
+
+  def test_parallel_threshold_boundary(self):
+    """Directories at and just above the threshold both produce valid
+    deterministic hashes."""
+    tmp = _make_temp_path(self)
+    for count in (_PARALLEL_HASH_THRESHOLD, _PARALLEL_HASH_THRESHOLD + 1):
+      d = tmp / f"dir_{count}"
+      d.mkdir()
+      for i in range(count):
+        (d / f"f{i}.txt").write_text(f"data{i}")
+
+      h1 = Data(str(d)).content_hash()
+      h2 = Data(str(d)).content_hash()
+      self.assertEqual(h1, h2)
+      self.assertEqual(len(h1), 64)
+
 
 class TestMakeDataRef(absltest.TestCase):
   def test_basic_ref(self):