1 | 1 | # Standard |
2 | 2 | from dataclasses import dataclass, field |
3 | 3 | from multiprocessing import Pool |
4 | | -from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union |
| 4 | +from typing import Any, Dict, List, TypedDict, TypeVar, Union |
5 | 5 | import gc |
6 | 6 | import glob |
7 | 7 | import logging |
@@ -276,60 +276,6 @@ def get_subset_name(self, size_spec: Union[int, float], actual_size: int) -> str |
276 | 276 | return f"percent_{size_spec:.1f}" |
277 | 277 | return f"samples_{actual_size}" |
278 | 278 |
279 | | - def get_last_processed_batch(self, output_dir: str) -> Tuple[int, Optional[str]]: |
280 | | - """ |
281 | | - Retrieves the last processed batch number and its file path from the output directory. |
282 | | -
283 | | - Args: |
284 | | - output_dir (str): The directory where batch files are stored. |
285 | | -
286 | | - Returns: |
287 | | - Tuple[int, Optional[str]]: The last batch number and the corresponding batch file path. |
288 | | - """ |
289 | | - batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5")) |
290 | | - if not batch_files: |
291 | | - return -1, None |
292 | | - |
293 | | - # Sort batch files by batch number |
294 | | - batch_files.sort(key=self.extract_batch_number) |
295 | | - max_batch_file = batch_files[-1] |
296 | | - max_batch_number = self.extract_batch_number(max_batch_file) |
297 | | - |
298 | | - # Return the max batch number and the corresponding batch file path |
299 | | - return max_batch_number, max_batch_file |
300 | | - |
301 | | - @retry_on_exception |
302 | | - def process_batch(self, batch_texts: List[str], output_file: str) -> Optional[int]: |
303 | | - """ |
304 | | - Processes a batch of texts by generating embeddings and saving them to a file. |
305 | | - Returns the embedding dimension or None if no embeddings were generated. |
306 | | - """ |
307 | | - embeddings = ( |
308 | | - self.encoder.encode( |
309 | | - inputs=batch_texts, |
310 | | - instruction=self.config.encoder.instruction, |
311 | | - ) |
312 | | - .cpu() |
313 | | - .numpy() |
314 | | - ) |
315 | | - |
316 | | - if embeddings.size == 0: |
317 | | - logger.warning( |
318 | | - f"No embeddings generated for batch, skipping file {output_file}" |
319 | | - ) |
320 | | - return None |
321 | | - |
322 | | - embedding_dim = int(embeddings.shape[1]) # Cast to int |
323 | | - logger.info(f"Embedding dimension for batch: {embedding_dim}") |
324 | | - |
325 | | - with h5py.File(output_file, "w") as h5f: |
326 | | - h5f.create_dataset( |
327 | | - "embeddings", data=embeddings, dtype="float32", chunks=True |
328 | | - ) |
329 | | - h5f.flush() |
330 | | - |
331 | | - return embedding_dim |
332 | | - |
333 | 279 | @retry_on_exception |
334 | 280 | def generate_embeddings(self, dataset, output_dir: str) -> str: |
335 | 281 | """ |
@@ -399,104 +345,6 @@ def generate_embeddings(self, dataset, output_dir: str) -> str: |
399 | 345 |
400 | 346 | return merged_path |
401 | 347 |
402 | | - def extract_batch_number(self, filename): |
403 | | - """ |
404 | | - Extracts the batch number from the filename. |
405 | | - Assumes the filename is in the format 'batch_<number>.h5'. |
406 | | -
407 | | - Args: |
408 | | - filename (str): The filename from which to extract the batch number. |
409 | | -
410 | | - Returns: |
411 | | - int: The batch number extracted from the filename. |
412 | | - """ |
413 | | - basename = os.path.basename(filename) |
414 | | - match = re.search(r"batch_(\d+)\.h5$", basename) |
415 | | - if match: |
416 | | - return int(match.group(1)) |
417 | | - raise ValueError(f"Filename {filename} does not match expected pattern.") |
418 | | - |
419 | | - def get_embedding_size_dim_from_file(self, batch_file: str) -> Tuple[int, int]: |
420 | | - """ |
421 | | - Reads the batch file to determine the embedding size (number of embeddings) and dimension. |
422 | | - """ |
423 | | - with h5py.File(batch_file, "r") as h5f: |
424 | | - if "embeddings" not in h5f: |
425 | | - raise ValueError( |
426 | | - f"The file {batch_file} does not contain 'embeddings' dataset." |
427 | | - ) |
428 | | - embeddings = h5f["embeddings"] |
429 | | - embedding_size = int(embeddings.shape[0]) # Cast to int |
430 | | - embedding_dim = int(embeddings.shape[1]) # Cast to int |
431 | | - logger.info(f"Embedding dimension from {batch_file}: {embedding_dim}") |
432 | | - return embedding_size, embedding_dim |
433 | | - |
434 | | - def merge_embeddings(self, output_dir, merged_file, total_samples): |
435 | | - """ |
436 | | - Merges all batch embedding files into a single embeddings file. |
437 | | -
438 | | - Args: |
439 | | - output_dir (str): The directory where batch embedding files are stored. |
440 | | - merged_file (str): The path to the merged embeddings file. |
441 | | - total_samples (int): The total number of samples (embeddings). |
442 | | -
443 | | - """ |
444 | | - # Find all batch files |
445 | | - batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5")) |
446 | | - if not batch_files: |
447 | | - logger.warning("No batch files found to merge") |
448 | | - return |
449 | | - |
450 | | - # Sort batch files by batch number |
451 | | - batch_files.sort(key=self.extract_batch_number) |
452 | | - |
453 | | - # Retrieve embedding_dim from the first batch file |
454 | | - _, embedding_dim = self.get_embedding_size_dim_from_file(batch_files[0]) |
455 | | - |
456 | | - if os.path.exists(merged_file): |
457 | | - logger.info(f"Merged file {merged_file} already exists, skipping merge") |
458 | | - return |
459 | | - |
460 | | - logger.info( |
461 | | - f"Merging {len(batch_files)} batch files into {merged_file} with {total_samples} samples" |
462 | | - ) |
463 | | - |
464 | | - with h5py.File(merged_file, "w") as h5f_merged: |
465 | | - # Initialize the dataset in the merged file with the retrieved embedding dimension |
466 | | - embeddings_ds = h5f_merged.create_dataset( |
467 | | - "embeddings", shape=(total_samples, embedding_dim), dtype="float32" |
468 | | - ) |
469 | | - |
470 | | - start_idx = 0 |
471 | | - for batch_file in batch_files: |
472 | | - with h5py.File(batch_file, "r") as h5f_batch: |
473 | | - if "embeddings" not in h5f_batch: |
474 | | - logger.error( |
475 | | - f"File {batch_file} does not contain 'embeddings' dataset" |
476 | | - ) |
477 | | - continue |
478 | | - |
479 | | - embeddings = h5f_batch["embeddings"][:] |
480 | | - batch_size = embeddings.shape[0] |
481 | | - end_idx = start_idx + batch_size |
482 | | - |
483 | | - # Check that each file's embedding dimension matches the retrieved embedding_dim |
484 | | - if embeddings.shape[1] != embedding_dim: |
485 | | - logger.error( |
486 | | - f"Embedding dimension mismatch in {batch_file}. Expected {embedding_dim}, got {embeddings.shape[1]}" |
487 | | - ) |
488 | | - continue |
489 | | - |
490 | | - # Copy embeddings into the merged dataset |
491 | | - embeddings_ds[start_idx:end_idx] = embeddings |
492 | | - start_idx = end_idx |
493 | | - |
494 | | - # Remove the batch file after processing |
495 | | - os.remove(batch_file) |
496 | | - logger.info(f"Processed and removed {batch_file}") |
497 | | - |
498 | | - gc.collect() |
499 | | - |
500 | 348 | def select_subsets( |
501 | 349 | self, dataset_name: str, embeddings: torch.Tensor |
502 | 350 | ) -> Dict[Union[int, float], List[int]]: |