77import hashlib
88import os
99import posixpath
10+ from concurrent .futures import ThreadPoolExecutor
11+ from functools import partial
1012
1113from absl import logging
1214
# Directories with more files than this threshold are hashed in parallel
# using a thread pool. Below this, sequential hashing avoids pool overhead.
_PARALLEL_HASH_THRESHOLD = 16
# Number of (relpath, fpath) pairs handed to each thread-pool task; batching
# avoids creating one Future per file for very large directories.
_HASH_BATCH_SIZE = 512
19+
20+
21+ def _hash_single_file (fpath : str , relpath : str ) -> bytes :
22+ """SHA-256 of relpath + \\ 0 + file contents. Returns raw 32-byte digest."""
23+ h = hashlib .sha256 ()
24+ h .update (relpath .encode ("utf-8" ))
25+ h .update (b"\0 " )
26+ with open (fpath , "rb" ) as f :
27+ for chunk in iter (partial (f .read , 65536 ), b"" ):
28+ h .update (chunk )
29+ return h .digest ()
30+
31+
def _hash_file_batch(batch: list[tuple[str, str]]) -> list[bytes]:
    """Hash a batch of (relpath, fpath) pairs.

    Returns the 32-byte digests in the same order as the input batch, so
    callers can concatenate results from multiple batches deterministically.
    """
    digests = []
    for relpath, fpath in batch:
        digests.append(_hash_single_file(fpath, relpath))
    return digests
35+
1336
1437class Data :
1538 """A reference to data that should be available on the remote pod.
@@ -78,6 +101,10 @@ def is_dir(self) -> bool:
def content_hash(self) -> str:
    """SHA-256 hash of all file contents, sorted by relative path.

    Uses two-level hashing for parallelism: each file is hashed
    independently (SHA-256 of relpath + contents), then per-file
    digests are combined in sorted order into a final hash.

    Includes a type prefix ("dir:" or "file:") to prevent collisions
    between a single file and a directory containing only that file.

    Returns:
        A hex-encoded SHA-256 digest string.

    Raises:
        ValueError: if the data is a GCS URI, whose contents are not
            locally readable.
    """
    if self.is_gcs:
        raise ValueError("Cannot compute content hash for GCS URI")
    # Dispatch on what the resolved path points at on the local filesystem.
    hasher = (
        self._content_hash_dir
        if os.path.isdir(self._resolved_path)
        else self._content_hash_file
    )
    return hasher()
91121
def _content_hash_file(self) -> str:
    """Hash a single regular file.

    The result is SHA-256 of the literal prefix b"file:" followed by the
    per-file digest (which itself covers basename + NUL + contents), so a
    lone file can never collide with a directory containing that file.
    """
    name = os.path.basename(self._resolved_path)
    file_digest = _hash_single_file(self._resolved_path, name)
    return hashlib.sha256(b"file:" + file_digest).hexdigest()
132+
def _content_hash_dir(self) -> str:
    """Hash a directory tree deterministically.

    Files are enumerated in filesystem order (better disk locality) and
    sorted once by relative path, so the result does not depend on walk
    order. Each file is hashed independently; the fixed-size per-file
    digests are then folded, in sorted order, into a single b"dir:"-
    prefixed SHA-256.
    """
    entries: list[tuple[str, str]] = []
    for root, _dirs, filenames in os.walk(self._resolved_path, followlinks=False):
        for filename in filenames:
            full = os.path.join(root, filename)
            entries.append((os.path.relpath(full, self._resolved_path), full))
    entries.sort()

    # Small directories are hashed inline; larger ones fan out to a
    # thread pool (I/O-bound reads release the GIL). Work is batched so
    # we do not create one pool task per file.
    if len(entries) <= _PARALLEL_HASH_THRESHOLD:
        digests = _hash_file_batch(entries)
    else:
        step = _HASH_BATCH_SIZE
        chunks = [entries[start:start + step] for start in range(0, len(entries), step)]
        workers = min(32, (os.cpu_count() or 4) + 4)
        digests = []
        with ThreadPoolExecutor(max_workers=workers) as pool:
            for chunk_digests in pool.map(_hash_file_batch, chunks):
                digests += chunk_digests

    # Each digest is exactly 32 bytes, so plain concatenation is
    # unambiguous — no separators needed.
    combined = hashlib.sha256(b"dir:")
    for digest in digests:
        combined.update(digest)
    return combined.hexdigest()
120166
121167 def __repr__ (self ):
0 commit comments