Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 71 additions & 25 deletions keras_remote/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,32 @@
import hashlib
import os
import posixpath
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from absl import logging

# Directories with more files than this threshold are hashed in parallel
# using a thread pool. Below this, sequential hashing avoids pool overhead.
_PARALLEL_HASH_THRESHOLD = 16
# Number of (relpath, fpath) pairs handed to the pool per work item, so a
# large directory does not create one future per file.
_HASH_BATCH_SIZE = 512


def _hash_single_file(fpath: str, relpath: str) -> bytes:
"""SHA-256 of relpath + \\0 + file contents. Returns raw 32-byte digest."""
h = hashlib.sha256()
h.update(relpath.encode("utf-8"))
h.update(b"\0")
with open(fpath, "rb") as f:
for chunk in iter(partial(f.read, 65536), b""):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you leave a comment as to why it's chunked this way

Copy link
Collaborator Author

@JyotinderSingh JyotinderSingh Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just aligned with the default chunk size for python stdlib utilities. 64-256kb is also the standard nvme page size.

shutil.copyfileobj uses a 64kb chunk, while hashlib.file_digest uses a 256kb chunk

h.update(chunk)
return h.digest()


def _hash_file_batch(batch: list[tuple[str, str]]) -> list[bytes]:
    """Compute per-file digests for a batch of (relpath, fpath) pairs.

    Returns one raw 32-byte SHA-256 digest per input pair, in order.
    """
    digests: list[bytes] = []
    for rel, path in batch:
        digests.append(_hash_single_file(path, rel))
    return digests


class Data:
"""A reference to data that should be available on the remote pod.
Expand Down Expand Up @@ -78,6 +101,10 @@ def is_dir(self) -> bool:
def content_hash(self) -> str:
    """SHA-256 content hash of the referenced file or directory.

    Uses two-level hashing for parallelism: each file is hashed
    independently (SHA-256 of relpath + contents), then the per-file
    digests are combined in sorted order into a final hash. A type
    prefix ("dir:" or "file:") prevents collisions between a single
    file and a directory containing only that file.

    Returns:
        Hex-encoded SHA-256 digest string.

    Raises:
        ValueError: If this data reference points at a GCS URI.
    """
    if self.is_gcs:
        raise ValueError("Cannot compute content hash for GCS URI")
    # Dispatch on what the resolved path actually is on disk.
    if os.path.isdir(self._resolved_path):
        return self._content_hash_dir()
    return self._content_hash_file()

def _content_hash_file(self) -> str:
    """Hash for a single-file target: "file:" prefix + per-file digest."""
    # The basename stands in for the relpath, so the file's own name is
    # part of the digest — mirroring the per-file scheme used for
    # directories.
    file_digest = _hash_single_file(
        self._resolved_path,
        os.path.basename(self._resolved_path),
    )
    h = hashlib.sha256()
    h.update(b"file:")
    h.update(file_digest)
    return h.hexdigest()

def _content_hash_dir(self) -> str:
    """Hash for a directory target: combine per-file digests.

    Walks the tree in filesystem order (better disk locality), sorts the
    collected (relpath, fpath) pairs once for determinism, hashes each
    file independently — using a thread pool for large directories — and
    folds the 32-byte digests, in sorted relpath order, into a final
    "dir:"-prefixed SHA-256.

    Returns:
        Hex-encoded SHA-256 digest string.
    """
    # Enumerate all files. Walk in filesystem order and sort once at the
    # end for determinism.
    # NOTE(review): this materializes one (relpath, fpath) pair per file
    # before any hashing starts; for extremely large trees (tens of
    # millions of files) the enumeration itself dominates memory — a
    # streaming/merge scheme would be needed to lift that ceiling.
    file_list = []
    for root, _dirs, files in os.walk(self._resolved_path, followlinks=False):
        for fname in files:
            fpath = os.path.join(root, fname)
            relpath = os.path.relpath(fpath, self._resolved_path)
            file_list.append((relpath, fpath))
    file_list.sort()

    # Hash each file independently. Use a thread pool for large
    # directories to parallelize the I/O-bound reads. Work is batched to
    # avoid creating one future per file; batches are generated lazily so
    # we never hold a second full list of slices alongside file_list.
    if len(file_list) <= _PARALLEL_HASH_THRESHOLD:
        digests = _hash_file_batch(file_list)
    else:
        batches = (
            file_list[i : i + _HASH_BATCH_SIZE]
            for i in range(0, len(file_list), _HASH_BATCH_SIZE)
        )
        max_workers = min(32, (os.cpu_count() or 4) + 4)
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            digests = []
            for batch_digests in pool.map(_hash_file_batch, batches):
                digests.extend(batch_digests)

    # Combine per-file digests (each exactly 32 bytes) into the final
    # hash; the order is the sorted relpath order fixed above.
    h = hashlib.sha256()
    h.update(b"dir:")
    for digest in digests:
        h.update(digest)
    return h.hexdigest()

def __repr__(self):
Expand Down
29 changes: 29 additions & 0 deletions keras_remote/data/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from absl.testing import absltest

from keras_remote.data import Data, _make_data_ref, is_data_ref
from keras_remote.data.data import _PARALLEL_HASH_THRESHOLD


def _make_temp_path(test_case):
Expand Down Expand Up @@ -216,6 +217,34 @@ def test_path_included_in_hash(self):
Data(str(d1)).content_hash(), Data(str(d2)).content_hash()
)

def test_parallel_determinism_many_files(self):
    """A directory large enough to take the thread-pool path must still
    hash deterministically across repeated invocations."""
    tmp = _make_temp_path(self)
    big = tmp / "large_dir"
    big.mkdir()
    for idx in range(_PARALLEL_HASH_THRESHOLD + 30):
        (big / f"file_{idx:04d}.txt").write_text(f"content_{idx}")

    first = Data(str(big)).content_hash()
    for _ in range(4):
        self.assertEqual(Data(str(big)).content_hash(), first)

def test_parallel_threshold_boundary(self):
    """Directories sized at and just above the parallelism threshold both
    yield stable, well-formed hashes."""
    tmp = _make_temp_path(self)
    for count in (_PARALLEL_HASH_THRESHOLD, _PARALLEL_HASH_THRESHOLD + 1):
        d = tmp / f"dir_{count}"
        d.mkdir()
        for i in range(count):
            (d / f"f{i}.txt").write_text(f"data{i}")

        ref = Data(str(d)).content_hash()
        self.assertEqual(Data(str(d)).content_hash(), ref)
        # 64 hex chars == 32-byte SHA-256 digest.
        self.assertEqual(len(ref), 64)


class TestMakeDataRef(absltest.TestCase):
def test_basic_ref(self):
Expand Down
Loading