Skip to content

Commit ec68497

Browse files
committed
[WIP] nicheformer-1
Adding nicheformer to the portfolio. This is not tested yet. Populating the stump with nicheformer. Changing version requirements to fit Nicheformer. Test fixture fix.
1 parent bc2b9cd commit ec68497

File tree

9 files changed

+542
-2
lines changed

9 files changed

+542
-2
lines changed

ci/tests/test_nicheformer/__init__.py

Whitespace-only changes.
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import numpy as np
2+
from pathlib import Path
3+
4+
from helical.models.nicheformer import NicheformerConfig
5+
6+
# URL prefix that every download URL produced by the config is checked against.
_HF_BASE_URL = "https://huggingface.co/theislab/Nicheformer/resolve/main"

# File names the config's download list is compared with (order is irrelevant
# because the comparison is set-based).
_EXPECTED_FILES = {
    # Python sources
    "__init__.py",
    "configuration_nicheformer.py",
    "masking.py",
    "modeling_nicheformer.py",
    "tokenization_nicheformer.py",
    # JSON files
    "config.json",
    "vocab.json",
    # binary artifacts
    "model.h5ad",
    "model.safetensors",
}
18+
19+
20+
class TestNicheformerConfig:
    """Checks on ``NicheformerConfig``: default values, overrides and download list."""

    # ---- defaults ---------------------------------------------------------

    def test_default_batch_size(self):
        settings = NicheformerConfig().config
        assert settings["batch_size"] == 32

    def test_default_device(self):
        settings = NicheformerConfig().config
        assert settings["device"] == "cpu"

    def test_default_layer(self):
        settings = NicheformerConfig().config
        assert settings["layer"] == -1

    def test_default_with_context(self):
        settings = NicheformerConfig().config
        assert settings["with_context"] is False

    def test_default_technology_mean(self):
        settings = NicheformerConfig().config
        assert settings["technology_mean"] is None

    def test_default_model_name(self):
        settings = NicheformerConfig().config
        assert settings["model_name"] == "theislab/Nicheformer"

    # ---- keyword overrides --------------------------------------------------

    def test_custom_batch_size(self):
        settings = NicheformerConfig(batch_size=16).config
        assert settings["batch_size"] == 16

    def test_custom_device(self):
        settings = NicheformerConfig(device="cuda").config
        assert settings["device"] == "cuda"

    def test_custom_layer(self):
        settings = NicheformerConfig(layer=6).config
        assert settings["layer"] == 6

    def test_custom_with_context(self):
        settings = NicheformerConfig(with_context=True).config
        assert settings["with_context"] is True

    def test_technology_mean_as_ndarray(self):
        # The very same array object must be stored, hence the identity check.
        mean_vector = np.ones(100)
        settings = NicheformerConfig(technology_mean=mean_vector).config
        assert settings["technology_mean"] is mean_vector

    def test_technology_mean_as_path_string(self):
        settings = NicheformerConfig(technology_mean="path/to/mean.npy").config
        assert settings["technology_mean"] == "path/to/mean.npy"

    # ---- download list ------------------------------------------------------

    def test_files_to_download_count(self):
        entries = NicheformerConfig().list_of_files_to_download
        assert len(entries) == len(_EXPECTED_FILES)

    def test_files_to_download_are_path_url_tuples(self):
        entries = NicheformerConfig().list_of_files_to_download
        for destination, source_url in entries:
            assert isinstance(destination, Path)
            assert isinstance(source_url, str)

    def test_files_to_download_urls_point_to_hf(self):
        entries = NicheformerConfig().list_of_files_to_download
        for _, source_url in entries:
            assert source_url.startswith(_HF_BASE_URL)

    def test_files_to_download_covers_all_expected_files(self):
        entries = NicheformerConfig().list_of_files_to_download
        # Reduce each URL to its last path component before comparing.
        seen_names = set()
        for _, source_url in entries:
            seen_names.add(Path(source_url).name)
        assert seen_names == _EXPECTED_FILES

    def test_local_paths_are_under_model_dir(self):
        cfg = NicheformerConfig()
        expected_parent = cfg.model_dir
        for destination, _ in cfg.list_of_files_to_download:
            assert destination.parent == expected_parent
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import pytest
2+
import numpy as np
3+
import torch
4+
import anndata as ad
5+
from anndata import AnnData
6+
from scipy.sparse import csr_matrix
7+
from datasets import Dataset
8+
9+
from helical.models.nicheformer import Nicheformer, NicheformerConfig
10+
11+
12+
@pytest.fixture
def _mocks(mocker):
    """Stub out the downloader, tokenizer loader and model loader.

    With these patches in place ``Nicheformer`` can be constructed without
    touching the network or the disk.

    Returns:
        tuple: ``(mock_tokenizer, mock_model)`` so tests can inspect the calls
        made on either mock.
    """

    def _tokenize(adata, **kwargs):
        # One row per observation, fixed sequence length of 1500.
        token_shape = (adata.n_obs, 1500)
        input_ids = torch.zeros(token_shape, dtype=torch.long)
        attention_mask = torch.ones(token_shape, dtype=torch.bool)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _get_embeddings(input_ids, attention_mask, layer, with_context):
        # One 512-wide zero vector per row of the incoming batch.
        batch_rows = input_ids.shape[0]
        return torch.zeros((batch_rows, 512))

    # Calling the tokenizer mock runs the stub above.
    mock_tokenizer = mocker.MagicMock(side_effect=_tokenize)

    mock_model = mocker.MagicMock()
    mock_model.get_embeddings.side_effect = _get_embeddings
    # ``model.to(...)`` hands back the same mock.
    mock_model.to.return_value = mock_model

    mocker.patch("helical.models.nicheformer.model.Downloader")
    mocker.patch(
        "helical.models.nicheformer.model.AutoTokenizer.from_pretrained",
        return_value=mock_tokenizer,
    )
    mocker.patch(
        "helical.models.nicheformer.model.AutoModelForMaskedLM.from_pretrained",
        return_value=mock_model,
    )

    return mock_tokenizer, mock_model
45+
46+
47+
@pytest.fixture
def nicheformer(_mocks):
    """Provide a default-configured ``Nicheformer`` built while ``_mocks`` is active."""
    model = Nicheformer()
    return model
50+
51+
52+
@pytest.fixture
def mock_adata():
    """Build a 3 x 3 ``AnnData`` of int32 values 1..9 with named cells and genes."""
    # Same matrix as [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
    counts = np.arange(1, 10, dtype=np.int32).reshape(3, 3)
    small_adata = AnnData(X=counts)
    small_adata.obs_names = ["cell1", "cell2", "cell3"]
    small_adata.var_names = ["GENE1", "GENE2", "GENE3"]
    return small_adata
58+
59+
60+
@pytest.fixture
def mock_adata_with_obs(mock_adata):
    """Return a copy of ``mock_adata`` with ``modality``, ``specie`` and ``assay`` obs columns."""
    annotated = mock_adata.copy()
    per_cell_columns = {
        "modality": ["dissociated", "spatial", "dissociated"],
        "specie": ["human", "human", "mouse"],
        "assay": ["10x 3' v3", "MERFISH", "10x 3' v2"],
    }
    # Dict order is insertion order, so the columns are added in the order listed.
    for column_name, values in per_cell_columns.items():
        annotated.obs[column_name] = values
    return annotated
67+
68+
69+
class TestNicheformerProcessData:
    """Tests for ``Nicheformer.process_data`` run against the mocked tokenizer."""

    def test_returns_dataset(self, nicheformer, mock_adata):
        processed = nicheformer.process_data(mock_adata)
        assert isinstance(processed, Dataset)

    def test_dataset_has_input_ids_column(self, nicheformer, mock_adata):
        processed = nicheformer.process_data(mock_adata)
        assert "input_ids" in processed.features

    def test_dataset_has_attention_mask_column(self, nicheformer, mock_adata):
        processed = nicheformer.process_data(mock_adata)
        assert "attention_mask" in processed.features

    def test_dataset_length_matches_n_obs(self, nicheformer, mock_adata):
        processed = nicheformer.process_data(mock_adata)
        assert len(processed) == mock_adata.n_obs

    def test_input_ids_sequence_length(self, nicheformer, mock_adata):
        processed = nicheformer.process_data(mock_adata)
        first_row = processed["input_ids"][0]
        assert len(first_row) == 1500

    def test_attention_mask_is_boolean(self, nicheformer, mock_adata):
        processed = nicheformer.process_data(mock_adata)
        mask = np.array(processed["attention_mask"])
        assert mask.dtype == bool

    def test_obs_metadata_columns_accepted(self, nicheformer, mock_adata_with_obs):
        processed = nicheformer.process_data(mock_adata_with_obs)
        assert len(processed) == mock_adata_with_obs.n_obs

    def test_sparse_matrix_input_accepted(self, nicheformer, mock_adata):
        # Swap the dense matrix for its CSR equivalent before processing.
        mock_adata.X = csr_matrix(mock_adata.X)
        processed = nicheformer.process_data(mock_adata)
        assert len(processed) == mock_adata.n_obs

    def test_float_counts_raises_value_error(self, nicheformer):
        # Path is relative to the working directory the tests are run from.
        sample = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
        sample.X = sample.X.astype(float)
        # Introduce a single non-integer value.
        sample.X[0, 0] = 0.5
        with pytest.raises(ValueError):
            nicheformer.process_data(sample, gene_names="index")

    def test_missing_gene_names_column_raises_key_error(self, nicheformer, mock_adata):
        with pytest.raises(KeyError):
            nicheformer.process_data(mock_adata, gene_names="nonexistent_col")
113+
114+
115+
class TestNicheformerGetEmbeddings:
    """Tests for ``Nicheformer.get_embeddings`` run against the mocked model."""

    def test_returns_ndarray(self, nicheformer, mock_adata):
        tokens = nicheformer.process_data(mock_adata)
        result = nicheformer.get_embeddings(tokens)
        assert isinstance(result, np.ndarray)

    def test_embedding_shape(self, nicheformer, mock_adata):
        tokens = nicheformer.process_data(mock_adata)
        result = nicheformer.get_embeddings(tokens)
        # The mocked model returns 512 values per row.
        assert result.shape == (mock_adata.n_obs, 512)

    def test_batching_produces_same_shape(self, nicheformer, mock_adata):
        tokens = nicheformer.process_data(mock_adata)

        # Run once with the smaller batch size, then with the larger one.
        nicheformer.config["batch_size"] = 1
        shape_small_batches = nicheformer.get_embeddings(tokens).shape

        nicheformer.config["batch_size"] = 32
        shape_large_batches = nicheformer.get_embeddings(tokens).shape

        assert shape_small_batches == shape_large_batches

    def test_layer_forwarded_to_model(self, nicheformer, mock_adata, _mocks):
        mock_model = _mocks[1]
        nicheformer.config["layer"] = 6
        tokens = nicheformer.process_data(mock_adata)
        nicheformer.get_embeddings(tokens)
        # Inspect the keyword arguments of the most recent model call.
        last_kwargs = mock_model.get_embeddings.call_args.kwargs
        assert last_kwargs["layer"] == 6

    def test_with_context_forwarded_to_model(self, nicheformer, mock_adata, _mocks):
        mock_model = _mocks[1]
        nicheformer.config["with_context"] = True
        tokens = nicheformer.process_data(mock_adata)
        nicheformer.get_embeddings(tokens)
        last_kwargs = mock_model.get_embeddings.call_args.kwargs
        assert last_kwargs["with_context"] is True
150+
151+
152+
class TestNicheformerTechnologyMean:
    """Check how ``technology_mean`` drives ``_load_technology_mean`` at construction.

    Each test builds a ``Nicheformer`` with a specific ``technology_mean``
    config value and then inspects the calls recorded on the mocked
    tokenizer's ``_load_technology_mean``.
    """

    def test_none_does_not_call_load(self, _mocks):
        # ``_mocks`` already depends on ``mocker``; requesting it again here
        # was unused, so the parameter has been dropped to match the sibling tests.
        mock_tokenizer, _ = _mocks
        Nicheformer(NicheformerConfig(technology_mean=None))
        mock_tokenizer._load_technology_mean.assert_not_called()

    def test_ndarray_calls_load_with_array(self, _mocks):
        mock_tokenizer, _ = _mocks
        arr = np.ones(100)
        Nicheformer(NicheformerConfig(technology_mean=arr))
        mock_tokenizer._load_technology_mean.assert_called_once_with(arr)

    def test_path_string_calls_load_with_path(self, _mocks):
        mock_tokenizer, _ = _mocks
        Nicheformer(NicheformerConfig(technology_mean="path/to/mean.npy"))
        mock_tokenizer._load_technology_mean.assert_called_once_with("path/to/mean.npy")

helical/models/nicheformer/LICENSE

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
BSD 3-Clause License
2+
3+
Copyright (c) 2024, Theislab
4+
All rights reserved.
5+
6+
Redistribution and use in source and binary forms, with or without
7+
modification, are permitted provided that the following conditions are met:
8+
9+
1. Redistributions of source code must retain the above copyright notice, this
10+
list of conditions and the following disclaimer.
11+
12+
2. Redistributions in binary form must reproduce the above copyright notice,
13+
this list of conditions and the following disclaimer in the documentation
14+
and/or other materials provided with the distribution.
15+
16+
3. Neither the name of the copyright holder nor the names of its
17+
contributors may be used to endorse or promote products derived from
18+
this software without specific prior written permission.
19+
20+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .model import Nicheformer
2+
from .nicheformer_config import NicheformerConfig

0 commit comments

Comments
 (0)