55This example demonstrates two approaches:
661. Single batch processing (quick test with few cells)
772. Full file iteration (memory-efficient processing of entire dataset)
8+
9+ Both approaches use the same iteration pattern for consistency.
810"""
911
1012from pathlib import Path
1618from vllm_biomed_rna_plugin .biomed_rna import (
1719 BiomedRnaForSequenceEmbedding , # Register model class
1820)
19- from vllm_biomed_rna_plugin .preprocess import iter_h5ad_batches , preprocess_anndata
21+ from vllm_biomed_rna_plugin .preprocess import preprocess_anndata
2022from vllm_biomed_rna_plugin .utils import DEFAULT_MODEL_PATH , load_tokenizer
2123
2224# Configuration
23- H5AD_PATH : Path = Path ("examples/resources/zheng68k.h5ad" )
25+ ZHENG_SMALL_H5AD_PATH : Path = Path ("examples/resources/zheng68k.h5ad" ) #165 samples
26+
27+
28+ def iter_h5ad_batches (
29+ h5ad_path : str | Path ,
30+ tokenizer ,
31+ batch_size : int = 32 ,
32+ max_length : int = 1024 ,
33+ limit_genes : str = "protein_coding" ,
34+ log_normalize_transform : bool = True ,
35+ limit_cells : int | None = None ,
36+ ):
37+ """
38+ Stream batches from h5ad file using DataModule preprocessing.
39+
40+ Memory-efficient processing with full bmfm-targets transformations:
41+ - Log normalization (if enabled)
42+ - Gene filtering (e.g., protein_coding only)
43+ - Sequence length limiting (max_length)
44+ - Attention mask generation
45+
46+ Uses backed="r" mode to avoid loading entire file into memory.
47+
48+ Args:
49+ ----
50+ h5ad_path: Path to h5ad file
51+ tokenizer: MultiFieldTokenizer from bmfm-targets
52+ batch_size: Number of cells per batch (default: 32)
53+ max_length: Maximum sequence length (default: 1024)
54+ limit_genes: Gene filtering strategy - "protein_coding" or None (default: "protein_coding")
55+ log_normalize_transform: Apply log normalization (default: True)
56+ limit_cells: Optional limit on total cells to process (default: None = all cells)
57+
58+ Yields:
59+ ------
60+ list[dict]: Batch of preprocessed cells in vLLM format
61+
62+ Example:
63+ -------
64+ >>> tokenizer = load_tokenizer()
65+ >>> llm = get_vllm_biomed_rna_model()
66+ >>>
67+ >>> all_embeddings = []
68+ >>> for batch in iter_h5ad_batches("data.h5ad", tokenizer, batch_size=32):
69+ >>> outputs = llm.embed(batch)
70+ >>> embeddings = [out.outputs.embedding for out in outputs]
71+ >>> all_embeddings.extend(embeddings)
72+ >>>
73+ >>> embeddings_array = np.array(all_embeddings) # [n_cells, hidden_size]
74+ """
75+ # backed="r" = read-only mode, doesn't load full matrix into memory
76+ adata = anndata .read_h5ad (str (h5ad_path ), backed = "r" )
77+ total_cells = adata .n_obs if limit_cells is None else min (limit_cells , adata .n_obs )
78+
79+ cells_processed = 0
80+ for start in range (0 , total_cells , batch_size ):
81+ end = min (start + batch_size , total_cells )
82+
83+ # Load chunk into memory
84+ chunk_adata = adata [start :end ].to_memory ()
85+
86+ # Preprocess using DataModule (applies all transformations)
87+ batch = preprocess_anndata (
88+ chunk_adata ,
89+ tokenizer ,
90+ max_length = max_length ,
91+ limit_genes = limit_genes ,
92+ log_normalize_transform = log_normalize_transform ,
93+ batch_size = None , # Process entire chunk at once
94+ )
95+
96+ yield batch
97+
98+ cells_processed = end
99+ if limit_cells and cells_processed >= limit_cells :
100+ break
24101
25102
26103def generate_embedding_for_h5ad_snippet (
27- h5ad_path : Path = H5AD_PATH ,
104+ h5ad_path : Path = ZHENG_SMALL_H5AD_PATH ,
28105 num_samples : int = 10 ,
29106 max_length : int = 1024 ,
30107) -> np .ndarray :
@@ -74,8 +151,8 @@ def generate_embedding_for_h5ad_snippet(
74151 return embeddings
75152
76153
77- def generate_embeddings_for_full_h5ad (
78- h5ad_path : Path = H5AD_PATH ,
154+ def generate_embeddings_for_h5ad (
155+ h5ad_path : Path = ZHENG_SMALL_H5AD_PATH ,
79156 batch_size : int = 32 ,
80157 max_length : int = 1024 ,
81158 limit_cells : int | None = None ,
@@ -84,7 +161,8 @@ def generate_embeddings_for_full_h5ad(
84161 Generate embeddings for entire h5ad file using batch iteration.
85162
86163 Memory-efficient: Processes file in chunks without loading everything
87- into memory at once. Uses DataModule preprocessing for each batch.
164+ into memory at once. Uses the iter_h5ad_batches helper function which
165+ applies full DataModule preprocessing for each batch.
88166
89167 Args:
90168 ----
@@ -101,41 +179,32 @@ def generate_embeddings_for_full_h5ad(
101179 print (f"Example 2: Full File Iteration (batch_size={ batch_size } )" )
102180 print (f"{ '=' * 80 } " )
103181
182+ # Initialize tokenizer and model
104183 tokenizer = load_tokenizer ()
105- llm = get_vllm_biomed_rna_model (
106- model_path = DEFAULT_MODEL_PATH ,
107- )
184+ llm = get_vllm_biomed_rna_model (model_path = DEFAULT_MODEL_PATH )
108185
109- # Get total cell count
186+ # Get total cell count for progress reporting
110187 adata_info = anndata .read_h5ad (h5ad_path , backed = "r" )
111188 total_cells = (
112189 adata_info .n_obs if limit_cells is None else min (limit_cells , adata_info .n_obs )
113190 )
114191 print (f"Processing { total_cells } cells from { h5ad_path .name } " )
115192
116- # Process in batches
193+ # Process in batches using the iteration helper
117194 all_embeddings = []
118- cells_processed = 0
119-
120195 for batch in iter_h5ad_batches (
121- str ( h5ad_path ) ,
196+ h5ad_path ,
122197 tokenizer ,
123198 batch_size = batch_size ,
124199 max_length = max_length ,
200+ limit_cells = limit_cells ,
125201 ):
126202 # Generate embeddings for this batch
127203 outputs = llm .embed (batch )
128- batch_embeddings : list [list [float ]] = [
129- output .outputs .embedding for output in outputs
130- ]
204+ batch_embeddings = [output .outputs .embedding for output in outputs ]
131205 all_embeddings .extend (batch_embeddings )
132-
133- cells_processed += len (batch )
134- print (f" Processed { cells_processed } /{ total_cells } cells..." )
135-
136- # Stop if we've reached the limit
137- if limit_cells and cells_processed >= limit_cells :
138- break
206+
207+ print (f" Processed { len (all_embeddings )} /{ total_cells } cells..." )
139208
140209 # Convert to numpy array
141210 embeddings = np .array (all_embeddings )
@@ -149,7 +218,7 @@ def generate_embeddings_for_full_h5ad(
149218 # Example 1: Quick test with 10 cells
150219 embeddings_snippet = generate_embedding_for_h5ad_snippet (num_samples = 10 )
151220
152- # Example 2: Process more cells using batch iteration
153- embeddings_full = generate_embeddings_for_full_h5ad (
221+ # Example 2: Process full h5ad file using batch iteration
222+ embeddings_full = generate_embeddings_for_h5ad (
154223 batch_size = 32 ,
155224 )
0 commit comments