Skip to content

Commit 6e7406a

Browse files
author
Sivan Ravid
committed
tests and examples update
1 parent 92ce269 commit 6e7406a

10 files changed

Lines changed: 568 additions & 108 deletions

File tree

vllm/README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
# vLLM BiomedRNA Model Plugins
1+
# vLLM BiomedRNA Model Plugin
22

3-
Running BiomedRNA models with vLLM through the plugin system.
3+
Run inference for BiomedRNA models via the vLLM plugin.
44

55
## Installation
66

7-
Install biomed-multi-omic and vllm plugin:
7+
Add the vLLM plugin to your bmfm-multi-omic environment:
88

99
```
10-
cd $HOME/git/biomed-multi-omic
11-
pip install -e ..
12-
cd vllm
13-
pip install -e .
10+
uv pip install -e .
1411
```
15-
`
1612

1713
## Prerequisites
1814

vllm/examples/biomed_rna_example.py

Lines changed: 95 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
This example demonstrates two approaches:
66
1. Single batch processing (quick test with few cells)
77
2. Full file iteration (memory-efficient processing of entire dataset)
8+
9+
Both approaches use the same iteration pattern for consistency.
810
"""
911

1012
from pathlib import Path
@@ -16,15 +18,90 @@
1618
from vllm_biomed_rna_plugin.biomed_rna import (
1719
BiomedRnaForSequenceEmbedding, # Register model class
1820
)
19-
from vllm_biomed_rna_plugin.preprocess import iter_h5ad_batches, preprocess_anndata
21+
from vllm_biomed_rna_plugin.preprocess import preprocess_anndata
2022
from vllm_biomed_rna_plugin.utils import DEFAULT_MODEL_PATH, load_tokenizer
2123

2224
# Configuration
23-
H5AD_PATH: Path = Path("examples/resources/zheng68k.h5ad")
25+
ZHENG_SMALL_H5AD_PATH: Path = Path("examples/resources/zheng68k.h5ad") #165 samples
26+
27+
28+
def iter_h5ad_batches(
29+
h5ad_path: str | Path,
30+
tokenizer,
31+
batch_size: int = 32,
32+
max_length: int = 1024,
33+
limit_genes: str = "protein_coding",
34+
log_normalize_transform: bool = True,
35+
limit_cells: int | None = None,
36+
):
37+
"""
38+
Stream batches from h5ad file using DataModule preprocessing.
39+
40+
Memory-efficient processing with full bmfm-targets transformations:
41+
- Log normalization (if enabled)
42+
- Gene filtering (e.g., protein_coding only)
43+
- Sequence length limiting (max_length)
44+
- Attention mask generation
45+
46+
Uses backed="r" mode to avoid loading entire file into memory.
47+
48+
Args:
49+
----
50+
h5ad_path: Path to h5ad file
51+
tokenizer: MultiFieldTokenizer from bmfm-targets
52+
batch_size: Number of cells per batch (default: 32)
53+
max_length: Maximum sequence length (default: 1024)
54+
limit_genes: Gene filtering strategy - "protein_coding" or None (default: "protein_coding")
55+
log_normalize_transform: Apply log normalization (default: True)
56+
limit_cells: Optional limit on total cells to process (default: None = all cells)
57+
58+
Yields:
59+
------
60+
list[dict]: Batch of preprocessed cells in vLLM format
61+
62+
Example:
63+
-------
64+
>>> tokenizer = load_tokenizer()
65+
>>> llm = get_vllm_biomed_rna_model()
66+
>>>
67+
>>> all_embeddings = []
68+
>>> for batch in iter_h5ad_batches("data.h5ad", tokenizer, batch_size=32):
69+
>>> outputs = llm.embed(batch)
70+
>>> embeddings = [out.outputs.embedding for out in outputs]
71+
>>> all_embeddings.extend(embeddings)
72+
>>>
73+
>>> embeddings_array = np.array(all_embeddings) # [n_cells, hidden_size]
74+
"""
75+
# backed="r" = read-only mode, doesn't load full matrix into memory
76+
adata = anndata.read_h5ad(str(h5ad_path), backed="r")
77+
total_cells = adata.n_obs if limit_cells is None else min(limit_cells, adata.n_obs)
78+
79+
cells_processed = 0
80+
for start in range(0, total_cells, batch_size):
81+
end = min(start + batch_size, total_cells)
82+
83+
# Load chunk into memory
84+
chunk_adata = adata[start:end].to_memory()
85+
86+
# Preprocess using DataModule (applies all transformations)
87+
batch = preprocess_anndata(
88+
chunk_adata,
89+
tokenizer,
90+
max_length=max_length,
91+
limit_genes=limit_genes,
92+
log_normalize_transform=log_normalize_transform,
93+
batch_size=None, # Process entire chunk at once
94+
)
95+
96+
yield batch
97+
98+
cells_processed = end
99+
if limit_cells and cells_processed >= limit_cells:
100+
break
24101

25102

26103
def generate_embedding_for_h5ad_snippet(
27-
h5ad_path: Path = H5AD_PATH,
104+
h5ad_path: Path = ZHENG_SMALL_H5AD_PATH,
28105
num_samples: int = 10,
29106
max_length: int = 1024,
30107
) -> np.ndarray:
@@ -74,8 +151,8 @@ def generate_embedding_for_h5ad_snippet(
74151
return embeddings
75152

76153

77-
def generate_embeddings_for_full_h5ad(
78-
h5ad_path: Path = H5AD_PATH,
154+
def generate_embeddings_for_h5ad(
155+
h5ad_path: Path = ZHENG_SMALL_H5AD_PATH,
79156
batch_size: int = 32,
80157
max_length: int = 1024,
81158
limit_cells: int | None = None,
@@ -84,7 +161,8 @@ def generate_embeddings_for_full_h5ad(
84161
Generate embeddings for entire h5ad file using batch iteration.
85162
86163
Memory-efficient: Processes file in chunks without loading everything
87-
into memory at once. Uses DataModule preprocessing for each batch.
164+
into memory at once. Uses the iter_h5ad_batches helper function which
165+
applies full DataModule preprocessing for each batch.
88166
89167
Args:
90168
----
@@ -101,41 +179,32 @@ def generate_embeddings_for_full_h5ad(
101179
print(f"Example 2: Full File Iteration (batch_size={batch_size})")
102180
print(f"{'='*80}")
103181

182+
# Initialize tokenizer and model
104183
tokenizer = load_tokenizer()
105-
llm = get_vllm_biomed_rna_model(
106-
model_path=DEFAULT_MODEL_PATH,
107-
)
184+
llm = get_vllm_biomed_rna_model(model_path=DEFAULT_MODEL_PATH)
108185

109-
# Get total cell count
186+
# Get total cell count for progress reporting
110187
adata_info = anndata.read_h5ad(h5ad_path, backed="r")
111188
total_cells = (
112189
adata_info.n_obs if limit_cells is None else min(limit_cells, adata_info.n_obs)
113190
)
114191
print(f"Processing {total_cells} cells from {h5ad_path.name}")
115192

116-
# Process in batches
193+
# Process in batches using the iteration helper
117194
all_embeddings = []
118-
cells_processed = 0
119-
120195
for batch in iter_h5ad_batches(
121-
str(h5ad_path),
196+
h5ad_path,
122197
tokenizer,
123198
batch_size=batch_size,
124199
max_length=max_length,
200+
limit_cells=limit_cells,
125201
):
126202
# Generate embeddings for this batch
127203
outputs = llm.embed(batch)
128-
batch_embeddings: list[list[float]] = [
129-
output.outputs.embedding for output in outputs
130-
]
204+
batch_embeddings = [output.outputs.embedding for output in outputs]
131205
all_embeddings.extend(batch_embeddings)
132-
133-
cells_processed += len(batch)
134-
print(f" Processed {cells_processed}/{total_cells} cells...")
135-
136-
# Stop if we've reached the limit
137-
if limit_cells and cells_processed >= limit_cells:
138-
break
206+
207+
print(f" Processed {len(all_embeddings)}/{total_cells} cells...")
139208

140209
# Convert to numpy array
141210
embeddings = np.array(all_embeddings)
@@ -149,7 +218,7 @@ def generate_embeddings_for_full_h5ad(
149218
# Example 1: Quick test with 10 cells
150219
embeddings_snippet = generate_embedding_for_h5ad_snippet(num_samples=10)
151220

152-
# Example 2: Process more cells using batch iteration
153-
embeddings_full = generate_embeddings_for_full_h5ad(
221+
# Example 2: Process full h5ad file using batch iteration
222+
embeddings_full = generate_embeddings_for_h5ad(
154223
batch_size=32,
155224
)

vllm/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
]
2727

2828
dependencies = [
29-
"vllm>=0.13.0",
29+
"vllm>0.18.0",
3030
"torch>=2.9.0",
3131
"transformers>=4.56.0,<5",
3232
]

vllm/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests package for vllm-biomed-rna-plugin."""

vllm/tests/conftest.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Shared pytest fixtures for BiomedRNA tests."""
2+
3+
import os
4+
from pathlib import Path
5+
6+
import pytest
7+
import torch
8+
from transformers import AutoConfig
9+
10+
from vllm_biomed_rna_plugin.biomed_rna import (
11+
BiomedRnaConfig,
12+
BiomedRnaForSequenceEmbedding,
13+
)
14+
15+
16+
def pytest_configure(config):
17+
"""Configure pytest and set environment variables for PyTorch."""
18+
# Disable TorchInductor compilation warnings
19+
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
20+
os.environ["TORCH_COMPILE_DEBUG"] = "0"
21+
22+
# Use eager mode to avoid compilation issues
23+
torch._dynamo.config.suppress_errors = True
24+
25+
# Set deterministic behavior
26+
torch.use_deterministic_algorithms(False)
27+
torch.backends.cudnn.deterministic = True
28+
torch.backends.cudnn.benchmark = False
29+
30+
31+
LOCAL_MODEL_PATH = Path(
32+
"/dccstor/bmfm-targets1/users/sivanra/models/biomed.rna.llama.47m.wced.multitask.v1"
33+
)
34+
HF_MODEL_PATH = "ibm-research/biomed.rna.llama.47m.wced.multitask.v1"
35+
36+
__all__ = [
37+
"create_dummy_vllm_config",
38+
"create_rna_multi_modal_object",
39+
"config",
40+
"model",
41+
]
42+
43+
44+
def create_rna_multi_modal_object(
    gene_ids: torch.Tensor, expr_values: torch.Tensor
) -> dict:
    """Wrap gene ids and expression values in a nested dict under the "rna" key."""
    rna_payload = dict(gene_ids=gene_ids, expr_values=expr_values)
    return dict(rna=rna_payload)
54+
55+
56+
def create_dummy_vllm_config(config: BiomedRnaConfig):
    """Build a minimal stand-in for a vLLM config object, for use in tests.

    The returned object exposes ``model_config``, which in turn carries the
    given ``config`` as ``hf_config``, float32 ``dtype`` / ``head_dtype``, a
    pooler config with ``seq_pooling_type = "CLS"`` and a dummy multimodal
    config.
    """

    class DummyMultiModalConfig:
        """Dummy multimodal config for testing."""

        # Required by SupportsMultiModal interface
        mm_encoder_only = False

        def get_limit_per_prompt(self, modality: str) -> int | None:
            """Return None to indicate no limit for the modality."""
            return None

    class DummyPoolerConfig:
        seq_pooling_type = "CLS"

    class DummyModelConfig:
        def __init__(self, hf_config):
            self.hf_config = hf_config
            # Both the model dtype and the head dtype are plain float32.
            self.dtype = torch.float32
            self.head_dtype = torch.float32
            self.pooler_config = DummyPoolerConfig()
            self.multimodal_config = DummyMultiModalConfig()

    class DummyVllmConfig:
        def __init__(self, hf_config):
            self.model_config = DummyModelConfig(hf_config)

    dummy = DummyVllmConfig(config)
    return dummy
85+
86+
87+
@pytest.fixture(scope="module")
def config():
    """Pytest fixture returning the model config loaded via AutoConfig.

    Prefers the local checkpoint directory ``LOCAL_MODEL_PATH``. That path is
    site-specific, so when it does not exist the fixture falls back to the
    Hugging Face Hub identifier ``HF_MODEL_PATH`` instead of failing.
    """
    source = LOCAL_MODEL_PATH if LOCAL_MODEL_PATH.exists() else HF_MODEL_PATH
    return AutoConfig.from_pretrained(source)
90+
91+
92+
@pytest.fixture(scope="module")
def model(config):
    """Pytest fixture for BiomedRNA model with loaded weights.

    Weights are read from ``model.safetensors`` inside ``LOCAL_MODEL_PATH``.
    That directory is site-specific, so when the file is not present the
    dependent tests are skipped with an explicit message rather than erroring
    inside the loader.
    """
    weights_path = LOCAL_MODEL_PATH / "model.safetensors"
    if not weights_path.is_file():
        pytest.skip(f"BiomedRNA model weights not found at {weights_path}")

    from safetensors.torch import load_file

    # Load weights
    weights = load_file(str(weights_path))

    # Create model with full config
    vllm_config = create_dummy_vllm_config(config)
    model = BiomedRnaForSequenceEmbedding(vllm_config=vllm_config)
    model.load_weights(weights.items())
    model.eval()
    return model

0 commit comments

Comments
 (0)