Skip to content

Commit 3ef4270

Browse files
authored
Fix (#1277)
1 parent fa40a6a commit 3ef4270

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -160,16 +160,17 @@ def predict_obs_metadata(
160160
# TODO: something more intelligent for numeric columns! also use distances, etc.
161161
max_joinid = neighbor_obs.index.max()
162162
out: dict[str, pd.Series[Any]] = {}
163-
indices = np.broadcast_to(np.arange(neighbors.neighbor_ids.shape[0]), (10, neighbors.neighbor_ids.shape[0])).T
163+
n_queries, n_neighbors = neighbors.neighbor_ids.shape
164+
indices = np.broadcast_to(np.arange(n_queries), (n_neighbors, n_queries)).T
164165
g = sparse.csr_matrix(
165166
(
166-
np.broadcast_to(1, neighbors.neighbor_ids.shape[0] * 10),
167+
np.broadcast_to(1, n_queries * n_neighbors),
167168
(
168169
indices.flatten(),
169170
neighbors.neighbor_ids.astype(np.int64).flatten(),
170171
),
171172
),
172-
shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1),
173+
shape=(n_queries, max_joinid + 1),
173174
)
174175
for col in column_names:
175176
col_categorical = neighbor_obs[col].astype("category")

api/python/cellxgene_census/tests/experimental/test_embeddings_search.py

+20
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,26 @@ def test_embeddings_search(true_neighbors: dict[str, Any], query_result: Neighbo
3939
return
4040

4141

42+
@pytest.mark.experimental
43+
@pytest.mark.live_corpus
44+
@pytest.mark.parametrize("n_neighbors", [5, 7, 20])
45+
def test_embedding_search_n_neighbors(query_anndata: ad.AnnData, n_neighbors: int) -> None:
46+
columns = ["cell_type"]
47+
result = find_nearest_obs(
48+
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
49+
TRUE_NEAREST_NEIGHBORS_ORGANISM,
50+
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
51+
query_anndata,
52+
k=n_neighbors,
53+
nprobe=25,
54+
)
55+
56+
# Check that the correct number of neighbors is being returned
57+
assert result.neighbor_ids.shape[1] == n_neighbors
58+
# Check that this step works
59+
_ = predict_obs_metadata(TRUE_NEAREST_NEIGHBORS_ORGANISM, TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, result, columns)
60+
61+
4262
@pytest.mark.experimental
4363
@pytest.mark.live_corpus
4464
def test_embeddings_search_errors(query_anndata: ad.AnnData) -> None:

0 commit comments

Comments
 (0)