
Commit 70ecbf8

eshwarprasadS authored and mergify[bot] committed

fix: remove redundant and unused legacy methods

Signed-off-by: eshwarprasadS <eshwarprasad.s01@gmail.com>
(cherry picked from commit f614e96)

1 parent e597cb6 · commit 70ecbf8

1 file changed: 1 addition & 153 deletions

src/instructlab/sdg/subset_selection.py
@@ -1,7 +1,7 @@
 # Standard
 from dataclasses import dataclass, field
 from multiprocessing import Pool
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
+from typing import Any, Dict, List, TypedDict, TypeVar, Union
 import gc
 import glob
 import logging
@@ -276,60 +276,6 @@ def get_subset_name(self, size_spec: Union[int, float], actual_size: int) -> str
             return f"percent_{size_spec:.1f}"
         return f"samples_{actual_size}"

-    def get_last_processed_batch(self, output_dir: str) -> Tuple[int, Optional[str]]:
-        """
-        Retrieves the last processed batch number and its file path from the output directory.
-
-        Args:
-            output_dir (str): The directory where batch files are stored.
-
-        Returns:
-            Tuple[int, Optional[str]]: The last batch number and the corresponding batch file path.
-        """
-        batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5"))
-        if not batch_files:
-            return -1, None
-
-        # Sort batch files by batch number
-        batch_files.sort(key=self.extract_batch_number)
-        max_batch_file = batch_files[-1]
-        max_batch_number = self.extract_batch_number(max_batch_file)
-
-        # Return the max batch number and the corresponding batch file path
-        return max_batch_number, max_batch_file
-
-    @retry_on_exception
-    def process_batch(self, batch_texts: List[str], output_file: str) -> Optional[int]:
-        """
-        Processes a batch of texts by generating embeddings and saving them to a file.
-        Returns the embedding dimension or None if no embeddings were generated.
-        """
-        embeddings = (
-            self.encoder.encode(
-                inputs=batch_texts,
-                instruction=self.config.encoder.instruction,
-            )
-            .cpu()
-            .numpy()
-        )
-
-        if embeddings.size == 0:
-            logger.warning(
-                f"No embeddings generated for batch, skipping file {output_file}"
-            )
-            return None
-
-        embedding_dim = int(embeddings.shape[1])  # Cast to int
-        logger.info(f"Embedding dimension for batch: {embedding_dim}")
-
-        with h5py.File(output_file, "w") as h5f:
-            h5f.create_dataset(
-                "embeddings", data=embeddings, dtype="float32", chunks=True
-            )
-            h5f.flush()
-
-        return embedding_dim
-
     @retry_on_exception
     def generate_embeddings(self, dataset, output_dir: str) -> str:
         """
@@ -399,104 +345,6 @@ def generate_embeddings(self, dataset, output_dir: str) -> str:

         return merged_path

-    def extract_batch_number(self, filename):
-        """
-        Extracts the batch number from the filename.
-        Assumes the filename is in the format 'batch_<number>.h5'.
-
-        Args:
-            filename (str): The filename from which to extract the batch number.
-
-        Returns:
-            int: The batch number extracted from the filename.
-        """
-        basename = os.path.basename(filename)
-        match = re.search(r"batch_(\d+)\.h5$", basename)
-        if match:
-            return int(match.group(1))
-        raise ValueError(f"Filename {filename} does not match expected pattern.")
-
-    def get_embedding_size_dim_from_file(self, batch_file: str) -> Tuple[int, int]:
-        """
-        Reads the batch file to determine the embedding size (number of embeddings) and dimension.
-        """
-        with h5py.File(batch_file, "r") as h5f:
-            if "embeddings" not in h5f:
-                raise ValueError(
-                    f"The file {batch_file} does not contain 'embeddings' dataset."
-                )
-            embeddings = h5f["embeddings"]
-            embedding_size = int(embeddings.shape[0])  # Cast to int
-            embedding_dim = int(embeddings.shape[1])  # Cast to int
-            logger.info(f"Embedding dimension from {batch_file}: {embedding_dim}")
-            return embedding_size, embedding_dim
-
-    def merge_embeddings(self, output_dir, merged_file, total_samples):
-        """
-        Merges all batch embedding files into a single embeddings file.
-
-        Args:
-            output_dir (str): The directory where batch embedding files are stored.
-            merged_file (str): The path to the merged embeddings file.
-            total_samples (int): The total number of samples (embeddings).
-
-        """
-        # Find all batch files
-        batch_files = glob.glob(os.path.join(output_dir, "batch_*.h5"))
-        if not batch_files:
-            logger.warning("No batch files found to merge")
-            return
-
-        # Sort batch files by batch number
-        batch_files.sort(key=self.extract_batch_number)
-
-        # Retrieve embedding_dim from the first batch file
-        _, embedding_dim = self.get_embedding_size_dim_from_file(batch_files[0])
-
-        if os.path.exists(merged_file):
-            logger.info(f"Merged file {merged_file} already exists, skipping merge")
-            return
-
-        logger.info(
-            f"Merging {len(batch_files)} batch files into {merged_file} with {total_samples} samples"
-        )
-
-        with h5py.File(merged_file, "w") as h5f_merged:
-            # Initialize the dataset in the merged file with the retrieved embedding dimension
-            embeddings_ds = h5f_merged.create_dataset(
-                "embeddings", shape=(total_samples, embedding_dim), dtype="float32"
-            )
-
-            start_idx = 0
-            for batch_file in batch_files:
-                with h5py.File(batch_file, "r") as h5f_batch:
-                    if "embeddings" not in h5f_batch:
-                        logger.error(
-                            f"File {batch_file} does not contain 'embeddings' dataset"
-                        )
-                        continue
-
-                    embeddings = h5f_batch["embeddings"][:]
-                    batch_size = embeddings.shape[0]
-                    end_idx = start_idx + batch_size
-
-                    # Check that each file's embedding dimension matches the retrieved embedding_dim
-                    if embeddings.shape[1] != embedding_dim:
-                        logger.error(
-                            f"Embedding dimension mismatch in {batch_file}. Expected {embedding_dim}, got {embeddings.shape[1]}"
-                        )
-                        continue
-
-                    # Copy embeddings into the merged dataset
-                    embeddings_ds[start_idx:end_idx] = embeddings
-                    start_idx = end_idx
-
-                # Remove the batch file after processing
-                os.remove(batch_file)
-                logger.info(f"Processed and removed {batch_file}")
-
-        gc.collect()
-
     def select_subsets(
         self, dataset_name: str, embeddings: torch.Tensor
     ) -> Dict[Union[int, float], List[int]]:
