Skip to content

Commit e3032c6

Browse files
Parallelizes Data.content_hash() for large datasets
1 parent 83f83c6 commit e3032c6

File tree

2 files changed

+100
-25
lines changed

2 files changed

+100
-25
lines changed

keras_remote/data/data.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,32 @@
77
import hashlib
88
import os
99
import posixpath
10+
from concurrent.futures import ThreadPoolExecutor
11+
from functools import partial
1012

1113
from absl import logging
1214

15+
# Directories with more files than this threshold are hashed in parallel
# using a thread pool. Below this, sequential hashing avoids pool overhead.
_PARALLEL_HASH_THRESHOLD = 16
# Number of (relpath, fpath) pairs handed to each thread-pool task; batching
# avoids creating one Future per file when hashing large directories.
_HASH_BATCH_SIZE = 512
19+
20+
21+
def _hash_single_file(fpath: str, relpath: str) -> bytes:
22+
"""SHA-256 of relpath + \\0 + file contents. Returns raw 32-byte digest."""
23+
h = hashlib.sha256()
24+
h.update(relpath.encode("utf-8"))
25+
h.update(b"\0")
26+
with open(fpath, "rb") as f:
27+
for chunk in iter(partial(f.read, 65536), b""):
28+
h.update(chunk)
29+
return h.digest()
30+
31+
32+
def _hash_file_batch(batch: list[tuple[str, str]]) -> list[bytes]:
    """Hash every (relpath, fpath) pair in *batch*.

    Returns the raw 32-byte per-file digests in the same order as the
    input pairs, so callers can concatenate them deterministically.
    """
    digests = []
    for relpath, fpath in batch:
        digests.append(_hash_single_file(fpath, relpath))
    return digests
35+
1336

1437
class Data:
1538
"""A reference to data that should be available on the remote pod.
@@ -78,6 +101,10 @@ def is_dir(self) -> bool:
78101
def content_hash(self) -> str:
79102
"""SHA-256 hash of all file contents, sorted by relative path.
80103
104+
Uses two-level hashing for parallelism: each file is hashed
105+
independently (SHA-256 of relpath + contents), then per-file
106+
digests are combined in sorted order into a final hash.
107+
81108
Includes a type prefix ("dir:" or "file:") to prevent collisions
82109
between a single file and a directory containing only that file.
83110
@@ -88,34 +115,53 @@ def content_hash(self) -> str:
88115
"""
89116
if self.is_gcs:
90117
raise ValueError("Cannot compute content hash for GCS URI")
118+
if os.path.isdir(self._resolved_path):
119+
return self._content_hash_dir()
120+
return self._content_hash_file()
91121

122+
def _content_hash_file(self) -> str:
    """Hash a single regular file at ``self._resolved_path``.

    The ``file:`` prefix keeps this hash domain-separated from directory
    hashes, so a lone file and a directory containing only that file can
    never collide. Returns a 64-char hex digest.
    """
    file_digest = _hash_single_file(
        self._resolved_path,
        os.path.basename(self._resolved_path),
    )
    combined = hashlib.sha256(b"file:" + file_digest)
    return combined.hexdigest()
132+
133+
def _content_hash_dir(self) -> str:
    """Hash a directory tree using two-level (per-file, then combined) hashing."""
    # Collect (relpath, fpath) pairs in filesystem walk order — better disk
    # locality — and sort once afterwards so the result is deterministic.
    base = self._resolved_path
    entries: list[tuple[str, str]] = []
    for root, _dirs, filenames in os.walk(base, followlinks=False):
        for name in filenames:
            full = os.path.join(root, name)
            entries.append((os.path.relpath(full, base), full))
    entries.sort()

    # Small directories are hashed inline; larger ones fan out to a thread
    # pool (hashing is dominated by I/O-bound reads, so threads overlap
    # well). Files are grouped into batches so we never create one Future
    # per file.
    if len(entries) > _PARALLEL_HASH_THRESHOLD:
        batches = [
            entries[start : start + _HASH_BATCH_SIZE]
            for start in range(0, len(entries), _HASH_BATCH_SIZE)
        ]
        workers = min(32, (os.cpu_count() or 4) + 4)
        with ThreadPoolExecutor(max_workers=workers) as pool:
            digests = [
                d for chunk in pool.map(_hash_file_batch, batches) for d in chunk
            ]
    else:
        digests = _hash_file_batch(entries)

    # Fold the fixed-width (exactly 32-byte) per-file digests, in sorted
    # relpath order, into the final domain-separated hash.
    combined = hashlib.sha256(b"dir:")
    for file_digest in digests:
        combined.update(file_digest)
    return combined.hexdigest()
120166

121167
def __repr__(self):

keras_remote/data/data_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from absl.testing import absltest
88

99
from keras_remote.data import Data, _make_data_ref, is_data_ref
10+
from keras_remote.data.data import _PARALLEL_HASH_THRESHOLD
1011

1112

1213
def _make_temp_path(test_case):
@@ -216,6 +217,34 @@ def test_path_included_in_hash(self):
216217
Data(str(d1)).content_hash(), Data(str(d2)).content_hash()
217218
)
218219

220+
def test_parallel_determinism_many_files(self):
    """Hashing a directory large enough to trigger the thread-pool path
    must produce the same digest on every call."""
    tmp = _make_temp_path(self)
    d = tmp / "large_dir"
    d.mkdir()
    for i in range(_PARALLEL_HASH_THRESHOLD + 30):
        (d / f"file_{i:04d}.txt").write_text(f"content_{i}")

    # Five independent computations must all agree.
    first = Data(str(d)).content_hash()
    for _ in range(4):
        self.assertEqual(Data(str(d)).content_hash(), first)
232+
233+
def test_parallel_threshold_boundary(self):
    """Hashes at the sequential/parallel boundary (threshold and
    threshold + 1 files) are deterministic and well-formed."""
    tmp = _make_temp_path(self)
    boundary_counts = (_PARALLEL_HASH_THRESHOLD, _PARALLEL_HASH_THRESHOLD + 1)
    for count in boundary_counts:
        d = tmp / f"dir_{count}"
        d.mkdir()
        for i in range(count):
            (d / f"f{i}.txt").write_text(f"data{i}")

        first_hash = Data(str(d)).content_hash()
        second_hash = Data(str(d)).content_hash()
        self.assertEqual(first_hash, second_hash)
        # SHA-256 hex digest is always 64 characters.
        self.assertEqual(len(first_hash), 64)
247+
219248

220249
class TestMakeDataRef(absltest.TestCase):
221250
def test_basic_ref(self):

0 commit comments

Comments
 (0)