Skip to content

Commit 362817d

Browse files
committed
Add tests for the CellAnnotator class
1 parent 46ea838 commit 362817d

File tree

5 files changed

+149
-51
lines changed

tests/test_base_annotator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
class TestBaseAnnotator:
1111
@pytest.fixture
1212
def base_annotator(self):
13-
return BaseAnnotator(
14-
species="human", tissue="brain", stage="adult", cluster_key="leiden", model="gpt-4o-mini", max_tokens=100
15-
)
13+
return BaseAnnotator(species="human", tissue="brain", stage="adult", cluster_key="leiden", model="gpt-4o-mini")
1614

1715
@patch("cell_annotator.base_annotator.BaseAnnotator.query_openai")
1816
def test_query_openai(self, mock_query_openai, base_annotator):

tests/test_cell_annotator.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pytest
2+
3+
from cell_annotator.cell_annotator import CellAnnotator
4+
5+
from .utils import expected_marker_genes, fibroblast_cell_types, get_example_data, neuronal_cell_types
6+
7+
8+
class TestCellAnnotator:
    """Tests for ``CellAnnotator`` on the synthetic neuron/fibroblast dataset.

    Both tests carry the ``openai`` marker.
    NOTE(review): presumably these call the live OpenAI API; confirm against the
    project's pytest marker configuration.
    """

    @staticmethod
    def _mentions_any(label, synonyms):
        """Return True if any entry of *synonyms* is a substring of *label*."""
        return any(synonym in label for synonym in synonyms)

    @pytest.fixture
    def cell_annotator(self):
        """Build a ``CellAnnotator`` over 200 synthetic cells labelled with two samples."""
        return CellAnnotator(
            adata=get_example_data(n_cells=200, n_samples=2),
            species="human",
            tissue="In vitro neurons and fibroblasts",
            stage="adult",
            cluster_key="leiden",
            sample_key="sample",
            model="gpt-4o-mini",
        )

    @pytest.mark.openai()
    def test_get_expected_cell_type_markers(self, cell_annotator):
        """The queried marker dict must contain neuron and fibroblast entries that overlap the references."""
        cell_annotator.get_expected_cell_type_markers()
        markers_by_type = cell_annotator.expected_marker_genes
        print("Expected Markers:", markers_by_type)

        assert markers_by_type is not None
        assert isinstance(markers_by_type, dict)

        def reference_recovered(reference_key, synonyms):
            # A returned cell type counts when its name contains one of the synonyms
            # AND it shares at least one gene with the reference marker list.
            for label, genes in markers_by_type.items():
                if not self._mentions_any(label, synonyms):
                    continue
                if set(expected_marker_genes[reference_key]).intersection(genes):
                    return True
            return False

        neuron_markers_found = reference_recovered("Neuron", neuronal_cell_types)
        fibroblast_markers_found = reference_recovered("Fibroblast", fibroblast_cell_types)

        assert neuron_markers_found
        assert fibroblast_markers_found

    @pytest.mark.openai()
    def test_annotate_clusters(self, cell_annotator):
        """Cluster markers are computed per sample, then clusters "0"/"1" are annotated as neuron/fibroblast."""
        # Step 1: compute cluster markers and validate the per-sample results.
        cell_annotator.get_cluster_markers(min_auc=0.6)

        required_columns = ("gene", "specificity", "auc")
        for sample in cell_annotator.sample_annotators.values():
            assert sample.marker_gene_dfs is not None
            assert sample.marker_genes is not None

            for marker_df in sample.marker_gene_dfs.values():
                assert not marker_df.empty
                for column in required_columns:
                    assert column in marker_df.columns

            for genes in sample.marker_genes.values():
                assert len(genes) > 0

        # Step 2: annotate using the fixed reference markers so this test does not
        # depend on the outcome of the marker query.
        cell_annotator.expected_marker_genes = expected_marker_genes
        cell_annotator.annotate_clusters(min_markers=1)

        for sample in cell_annotator.sample_annotators.values():
            print("Sample Annotation:\n", sample.annotation_df[["n_cells", "cell_type"]])

            # Cluster "0" carries the neuronal markers, cluster "1" the fibroblast markers.
            neuron_annotation_found = self._mentions_any(sample.annotation_dict["0"].cell_type, neuronal_cell_types)
            fibroblast_annotation_found = self._mentions_any(
                sample.annotation_dict["1"].cell_type, fibroblast_cell_types
            )

            assert neuron_annotation_found
            assert fibroblast_annotation_found

tests/test_sample_annotator.py

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,12 @@
11
from unittest.mock import patch
22

3-
import numpy as np
43
import pandas as pd
54
import pytest
6-
import scanpy as sc
7-
from anndata import AnnData
8-
from numpy.random import binomial, negative_binomial
95

106
from cell_annotator._response_formats import CellTypeMappingOutput, PredictedCellTypeOutput
117
from cell_annotator.sample_annotator import SampleAnnotator
128

13-
# Declare the dictionary of expected marker genes
14-
expected_marker_genes = {
15-
"Neuron": ["MAP2", "NEFL", "RBFOX3", "SYN1", "GAP43", "DCX", "TUBB3", "NEUROD1", "STMN2", "ENO2"],
16-
"Fibroblast": ["COL1A1", "COL3A1", "VIM", "ACTA2", "FAP", "PDGFRA", "THY1", "FN1", "SPARC", "S100A4"],
17-
}
18-
19-
20-
def get_example_data(n_cells: int = 100) -> AnnData:
21-
"""Create example data for testing. Adapted from scanpy.
22-
23-
The data consists of two clusters with different marker genes. The first cluster is enriched for neuronal markers and the second cluster is enriched for fibroblast markers."""
24-
gene_names = expected_marker_genes["Neuron"] + expected_marker_genes["Fibroblast"]
25-
n_genes = len(gene_names)
26-
adata = AnnData(np.multiply(binomial(1, 0.15, (n_cells, n_genes)), negative_binomial(2, 0.25, (n_cells, n_genes))))
27-
adata.var_names = gene_names
28-
29-
# Create marker genes for the two clusters
30-
n_group_1 = np.floor(0.3 * n_cells).astype(int)
31-
n_group_2 = n_cells - n_group_1
32-
n_marker_genes = int(n_genes / 2)
33-
34-
adata.X[:n_group_1, :10] = np.multiply(
35-
binomial(1, 0.9, (n_group_1, n_marker_genes)), negative_binomial(1, 0.5, (n_group_1, n_marker_genes))
36-
)
37-
adata.X[n_group_1:, 10:] = np.multiply(
38-
binomial(1, 0.9, (n_group_2, n_marker_genes)), negative_binomial(1, 0.5, (n_group_2, n_marker_genes))
39-
)
40-
41-
# Create cluster according to groups
42-
adata.obs["leiden"] = pd.Categorical(np.concatenate((n_group_1 * ["0"], n_group_2 * ["1"])))
43-
44-
# filter, normalize and log transform the data
45-
sc.pp.filter_cells(adata, min_counts=2)
46-
adata.raw = adata.copy()
47-
sc.pp.normalize_total(adata, target_sum=1e4)
48-
sc.pp.log1p(adata)
49-
50-
return adata
9+
from .utils import expected_marker_genes, fibroblast_cell_types, get_example_data, neuronal_cell_types
5110

5211

5312
class TestSampleAnnotator:
@@ -57,9 +16,9 @@ def sample_annotator(self):
5716

5817
return SampleAnnotator(
5918
adata=adata,
60-
sample_name="sample1",
19+
sample_name="sample_1",
6120
species="human",
62-
tissue="brain",
21+
tissue="In vitro neurons and fibroblasts",
6322
stage="adult",
6423
cluster_key="leiden",
6524
model="gpt-4o-mini",
@@ -112,5 +71,13 @@ def test_annotate_clusters_actual(self, sample_annotator):
11271
sample_annotator.get_cluster_markers(min_auc=0.6)
11372
sample_annotator.annotate_clusters(min_markers=1, expected_marker_genes=expected_marker_genes)
11473

115-
assert sample_annotator.annotation_dict["0"].cell_type == "Neuron"
116-
assert sample_annotator.annotation_dict["1"].cell_type == "Fibroblast"
74+
neuron_annotation_found = any(
75+
neuron_synonym in sample_annotator.annotation_dict["0"].cell_type for neuron_synonym in neuronal_cell_types
76+
)
77+
fibroblast_annotation_found = any(
78+
fibroblast_synonym in sample_annotator.annotation_dict["1"].cell_type
79+
for fibroblast_synonym in fibroblast_cell_types
80+
)
81+
82+
assert neuron_annotation_found
83+
assert fibroblast_annotation_found

tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def setup_data():
2626
clust_mask = np.array([True, False, True, False])
2727

2828
# Create raw count data with gene names
29-
raw_counts = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
29+
raw_counts = np.array([[1, 0], [0, 1], [1, 1], [4, 0]])
3030
adata = sc.AnnData(X=raw_counts, var=pd.DataFrame(index=genes))
3131
adata.raw = adata.copy() # Set raw data
3232

@@ -119,7 +119,7 @@ def test_query_openai(self, MockOpenAI):
119119
assert response.parsed_response == "parsed_response"
120120
mock_client.beta.chat.completions.parse.assert_called_once()
121121

122-
@pytest.mark.opanai()
122+
@pytest.mark.openai()
123123
def test_query_openai_actual(self):
124124
response = _query_openai(
125125
agent_description="Test agent",

tests/utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
import pandas as pd
3+
import scanpy as sc
4+
from anndata import AnnData
5+
from numpy.random import binomial, negative_binomial
6+
7+
# Reference marker genes per cell type. The synthetic dataset built by
# get_example_data() uses exactly these genes, Neuron markers first.
expected_marker_genes = {
    "Neuron": ["MAP2", "NEFL", "RBFOX3", "SYN1", "GAP43", "DCX", "TUBB3", "NEUROD1", "STMN2", "ENO2"],
    "Fibroblast": ["COL1A1", "COL3A1", "VIM", "ACTA2", "FAP", "PDGFRA", "THY1", "FN1", "SPARC", "S100A4"],
}

# Accepted spellings for each cell type. The tests match these as case-sensitive
# substrings of the predicted label, so both capitalised and lowercase variants
# are listed for each type (e.g. "dermal fibroblasts" and "cortical neuron" must
# match just like "Fibroblasts" and "Neuron").
neuronal_cell_types = ["Neuron", "Neurons", "Neuronal cells", "neurons", "neuron", "neuronal cells"]
fibroblast_cell_types = [
    "Fibroblast",
    "Fibroblasts",
    "fibroblast cells",
    "Fibroblast cells",
    "fibroblast",
    "fibroblasts",
]
16+
17+
18+
def get_example_data(n_cells: int = 100, n_samples: int = 1, seed: "int | None" = None) -> AnnData:
    """Create a small synthetic dataset for testing. Adapted from scanpy.

    The genes are the reference markers in ``expected_marker_genes`` (Neuron
    markers first, then Fibroblast markers). Cells are split into two clusters:
    the first ~30% (cluster ``"0"``) are enriched for the neuronal markers, the
    remaining cells (cluster ``"1"``) for the fibroblast markers.

    Parameters
    ----------
    n_cells
        Number of cells to simulate before filtering.
    n_samples
        If greater than 1, each cell is randomly assigned one of
        ``sample_0 .. sample_{n_samples-1}`` in ``adata.obs["sample"]``.
        Otherwise no ``sample`` column is created.
    seed
        Optional seed for reproducible data. When ``None`` (the default) the
        values are drawn from NumPy's global random state, exactly as before,
        so callers relying on ``np.random.seed`` keep working.

    Returns
    -------
    AnnData with cluster labels in ``obs["leiden"]``, the pre-normalisation
    counts stored in ``.raw``, and normalised, log-transformed values in ``.X``.
    Cells are filtered with ``min_counts=2``, so the result presumably can hold
    fewer than ``n_cells`` cells.
    """
    # The module-level ``np.random`` functions and a ``RandomState`` instance expose
    # the same sampling methods, so one code path serves both cases.
    rng = np.random if seed is None else np.random.RandomState(seed)

    neuron_genes = expected_marker_genes["Neuron"]
    fibroblast_genes = expected_marker_genes["Fibroblast"]
    gene_names = neuron_genes + fibroblast_genes
    n_genes = len(gene_names)

    # Sparse, low-level background expression for every cell and gene.
    background = np.multiply(
        rng.binomial(1, 0.15, (n_cells, n_genes)), rng.negative_binomial(2, 0.25, (n_cells, n_genes))
    )
    adata = AnnData(background)
    adata.var_names = gene_names

    # Create marker genes for the two clusters. Column boundaries come from the
    # marker lists themselves rather than a hard-coded split, so the lists may
    # change length independently.
    n_group_1 = int(np.floor(0.3 * n_cells))
    n_group_2 = n_cells - n_group_1
    n_neuron_genes = len(neuron_genes)
    n_fibroblast_genes = n_genes - n_neuron_genes

    adata.X[:n_group_1, :n_neuron_genes] = np.multiply(
        rng.binomial(1, 0.9, (n_group_1, n_neuron_genes)),
        rng.negative_binomial(1, 0.5, (n_group_1, n_neuron_genes)),
    )
    adata.X[n_group_1:, n_neuron_genes:] = np.multiply(
        rng.binomial(1, 0.9, (n_group_2, n_fibroblast_genes)),
        rng.negative_binomial(1, 0.5, (n_group_2, n_fibroblast_genes)),
    )

    # Create cluster labels according to the groups
    adata.obs["leiden"] = pd.Categorical(np.concatenate((n_group_1 * ["0"], n_group_2 * ["1"])))

    # Add sample information if there are multiple samples
    if n_samples > 1:
        sample_names = [f"sample_{i}" for i in range(n_samples)]
        adata.obs["sample"] = rng.choice(sample_names, size=n_cells)

    # Filter, then keep the counts in .raw before normalising and log-transforming
    sc.pp.filter_cells(adata, min_counts=2)
    adata.raw = adata.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    return adata

0 commit comments

Comments
 (0)