@@ -89,12 +89,41 @@ def _load_manifest(self) -> Optional[np.ndarray]:
8989 logger .error (f"Manifest file not found: { path } " )
9090 return None
9191 try :
92- with open (path , "rb" ) as f :
93- data = np .frombuffer (f .read (), dtype = np .uint32 ).reshape (- 1 , 2 )
94- logger .info (f"Manifest loaded from { path } ({ len (data )} rows)." )
92+ file_size_bytes = path .stat ().st_size
93+ # Heuristic: older 2-field format is 8 bytes per entry (two uint32),
94+ # newer 3-field format is 16 bytes per entry (int32, int32, int64).
95+ if file_size_bytes % 16 == 0 :
96+ # New format with 3 fields (chunk_id, num_tokens, offset)
97+ manifest_dtype = np .dtype ([("chunk_id" , np .int32 ), ("num_tokens" , np .int32 ), ("offset" , np .int64 )])
98+ data_structured = np .fromfile (path , dtype = manifest_dtype )
99+ logger .info (
100+ f"Manifest loaded (3-field format) from { path } ({ data_structured .shape [0 ]} chunks). Expanding to per-row entries."
101+ )
102+ # Expand into per-row entries expected by downstream (chunk_id, row_in_chunk)
103+ chunk_ids = data_structured ["chunk_id" ].astype (np .uint32 )
104+ num_tokens_arr = data_structured ["num_tokens" ].astype (np .uint32 )
105+ # Compute total rows
106+ total_rows = int (num_tokens_arr .sum ())
107+ logger .info (f"Expanding manifest: total rows = { total_rows } " )
108+ # Pre-allocate array
109+ data = np .empty ((total_rows , 2 ), dtype = np .uint32 )
110+ row_ptr = 0
111+ for cid , ntok in zip (chunk_ids , num_tokens_arr ):
112+ data [row_ptr : row_ptr + ntok , 0 ] = cid # chunk_id column
113+ data [row_ptr : row_ptr + ntok , 1 ] = np .arange (ntok , dtype = np .uint32 ) # row index within chunk
114+ row_ptr += ntok
115+ elif file_size_bytes % 8 == 0 :
116+ # Legacy 2-field format already matches expected shape
117+ data = np .fromfile (path , dtype = np .uint32 ).reshape (- 1 , 2 )
118+ logger .info (f"Manifest loaded (legacy 2-field format) from { path } ({ data .shape [0 ]} rows)." )
119+ else :
120+ logger .error (
121+ f"Manifest file size ({ file_size_bytes } bytes) is not compatible with known formats (8 or 16 bytes per row)."
122+ )
123+ return None
95124 return data
96125 except ValueError as e :
97- logger .error (f"Error reshaping manifest data from { path } (expected Nx2) : { e } " )
126+ logger .error (f"Error parsing manifest data from { path } : { e } " )
98127 return None
99128 except OSError as e :
100129 logger .error (f"Error reading manifest file { path } : { e } " )
@@ -117,7 +146,7 @@ def _load_norm_stats(self) -> Optional[Dict[str, Any]]:
117146 logger .error (f"Error reading norm_stats file { path } : { e } " )
118147 return None
119148
120- @lru_cache (maxsize = 256 )
149+ @lru_cache (maxsize = 64 )
121150 def _load_chunk (self , chunk_path : str , layer_key : str , data_type : str ):
122151 """Loads entire HDF5 chunk from disk and caches"""
123152
@@ -129,14 +158,29 @@ def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
129158 logger .error (f"Chunk file not found for fetch: { chunk_path } " )
130159 raise
131160 except KeyError as e :
132- raise RuntimeError (f"Missing 'inputs' or 'targets' dataset in layer group '{ layer_key } ' of chunk { chunk_path } " ) from e
161+ raise RuntimeError (
162+ f"Missing 'inputs' or 'targets' dataset in layer group '{ layer_key } ' of chunk { chunk_path } "
163+ ) from e
133164 except Exception as e :
134165 logger .error (f"Failed to open chunk at { chunk_path } : { e } " )
135166 raise RuntimeError (f"Failed to access chunk HDF5 file: { chunk_path } " ) from e
136167
137168 def _fetch_slice (self , chunk_id : int , row_indices : np .ndarray ) -> bytes :
138169
139170 chunk_path = self .dataset_path / f"chunk_{ chunk_id } .h5"
171+ if not chunk_path .exists ():
172+ # Fall back to .hdf5 extension (newer generator default)
173+ alt_path = self .dataset_path / f"chunk_{ chunk_id } .hdf5"
174+ if alt_path .exists ():
175+ chunk_path = alt_path
176+ else :
177+ # Provide clearer error message before _open_h5 raises
178+ logger .error (
179+ "Chunk file for chunk_id %d not found with either .h5 or .hdf5 extension in %s" ,
180+ chunk_id ,
181+ self .dataset_path ,
182+ )
183+
140184 hf = _open_h5 (chunk_path )
141185
142186 try :
@@ -162,8 +206,8 @@ def _layer_sort_key(name: str) -> int:
162206 row_indices_h5 = row_indices
163207
164208 for i , lk in enumerate (layer_keys ):
165- input_data = self ._load_chunk (chunk_path , lk , ' inputs' )[row_indices_h5 , :]
166- target_data = self ._load_chunk (chunk_path , lk , ' targets' )[row_indices_h5 , :]
209+ input_data = self ._load_chunk (chunk_path , lk , " inputs" )[row_indices_h5 , :]
210+ target_data = self ._load_chunk (chunk_path , lk , " targets" )[row_indices_h5 , :]
167211 bufs .append (input_data .tobytes ())
168212 bufs .append (target_data .tobytes ())
169213 return b"" .join (bufs )
0 commit comments