Skip to content

Commit 8db75e6

Browse files
committed
make new localization file
1 parent af3cb5b commit 8db75e6

File tree

6 files changed

+172
-160
lines changed

6 files changed

+172
-160
lines changed

grassp/tests/test_plotting_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from grassp.plotting import clustering, heatmaps, integration, qc, ternary # noqa: E402
2323
from grassp.preprocessing import enrichment, simple # noqa: E402
24-
from grassp.tools import clustering as tl_clustering # noqa: E402
24+
from grassp.tools import localization as tl_localization # noqa: E402
2525
from grassp.tools import scoring # noqa: E402
2626

2727
# ==============================================================================
@@ -211,7 +211,7 @@ def test_knn_violin_smoke(self):
211211
)
212212

213213
# Run KNN annotation to get predictions
214-
tl_clustering.knn_annotation(
214+
tl_localization.knn_annotation(
215215
adata, gt_col="markers", key_added="knn_pred", min_probability=0
216216
)
217217

grassp/tests/test_tools_integration.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
clustering,
2525
enrichment,
2626
integration,
27+
localization,
2728
scoring,
2829
tagm,
2930
)
@@ -273,7 +274,7 @@ def test_knn_annotation_basic(self):
273274
n_proteins=100, marker_fraction=0.3, add_neighbors=True
274275
)
275276

276-
clustering.knn_annotation(
277+
localization.knn_annotation(
277278
adata,
278279
gt_col="markers",
279280
key_added="knn_annotation",
@@ -294,7 +295,7 @@ def test_knn_annotation_fix_markers(self):
294295
n_proteins=100, marker_fraction=0.3, add_neighbors=True
295296
)
296297

297-
clustering.knn_annotation(
298+
localization.knn_annotation(
298299
adata,
299300
gt_col="markers",
300301
key_added="knn_fixed",
@@ -632,7 +633,7 @@ def test_knn_f1_score_with_prediction(self):
632633
)
633634

634635
# First create predictions (use min_probability=0 to get all predictions)
635-
clustering.knn_annotation(
636+
localization.knn_annotation(
636637
adata, gt_col="markers", key_added="predictions", min_probability=0
637638
)
638639

@@ -1220,7 +1221,7 @@ def test_clustering_annotation_workflow(self):
12201221
assert "mc_cluster" in adata.obs.columns
12211222

12221223
# Step 2: KNN annotation
1223-
clustering.knn_annotation(
1224+
localization.knn_annotation(
12241225
adata,
12251226
gt_col="markers",
12261227
key_added="knn_annotation",
@@ -1305,7 +1306,7 @@ def test_knn_annotation_missing_column(self):
13051306
adata = make_enriched_data_with_structure(n_proteins=50, add_neighbors=True)
13061307

13071308
with pytest.raises(KeyError):
1308-
clustering.knn_annotation(adata, gt_col="nonexistent_column")
1309+
localization.knn_annotation(adata, gt_col="nonexistent_column")
13091310

13101311
def test_silhouette_score_missing_embedding(self):
13111312
"""Test error when embedding not found."""

grassp/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from .clustering import (
22
calculate_interfacialness_score,
33
get_n_nearest_neighbors,
4-
knn_annotation,
54
leiden_mito_sweep,
65
markov_clustering,
76
to_knn_graph,
87
)
98
from .enrichment import calculate_cluster_enrichment, rank_proteins_groups
109
from .integration import align_adatas, aligned_umap, mr_score, remodeling_score
10+
from .localization import knn_annotation, knn_annotation_old
1111
from .scoring import (
1212
calinski_habarasz_score,
1313
class_balance,

grassp/tools/clustering.py

Lines changed: 2 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import pandas as pd
1212
import scanpy as sc
1313

14+
from .localization import _get_knn_annotation_df
15+
1416

1517
def _get_clusters(matrix):
1618
# get the attractors - non-zero elements of the matrix diagonal
@@ -149,157 +151,6 @@ def leiden_mito_sweep(
149151
data.uns["leiden"]["mito_majority_fraction"] = mito_majority_fraction
150152

151153

152-
def _get_knn_annotation_df(
    data: AnnData, obs_ann_col: str, exclude_category: str | List[str] | None = None
) -> pd.DataFrame:
    """Tile one ``.obs`` column into a square observation-by-observation frame.

    Every row of the result holds the complete ``data.obs[obs_ann_col]``
    column, so entry ``(i, j)`` is the annotation of observation ``j``.
    The result is dense (``n_obs`` x ``n_obs``).

    Parameters
    ----------
    data
        Object exposing an ``.obs`` data frame.
    obs_ann_col
        Name of the ``.obs`` column to tile.
    exclude_category
        One label or a list of labels that are turned into missing values in
        the returned frame. ``data.obs`` itself is left untouched.

    Returns
    -------
    Data frame with ``n_obs`` rows and ``n_obs`` columns and default integer
    row/column labels.
    """
    nrow = data.obs.shape[0]
    obs_ann = data.obs[obs_ann_col]
    if isinstance(exclude_category, str):
        exclude_category = [exclude_category]
    if exclude_category is not None:
        # Build a masked copy instead of replacing in place: the Series comes
        # straight from ``data.obs`` and an in-place edit could overwrite the
        # caller's annotation column.
        obs_ann = obs_ann.where(~obs_ann.isin(exclude_category))

    return pd.DataFrame(np.tile(obs_ann, (nrow, 1)))
167-
168-
169-
def knn_annotation(
    data,
    gt_col,
    fix_markers=False,
    class_balance=True,
    min_probability=0.5,
    inplace=True,
    obsp_key="connectivities",
    key_added="knn_annotation",
):
    """Propagate categorical annotations along the *k*-NN graph.

    For each observation the function inspects its neighbourhood in
    ``data.obsp[obsp_key]`` (generated by :func:`scanpy.pp.neighbors`) and
    calculates a weighted probability for each label category.

    Parameters
    ----------
    data
        :class:`anndata.AnnData` with a populated neighbour graph (*distances*
        or *connectivities*).
    gt_col
        Observation column containing the *source* annotations to be
        propagated. Observations with a missing value are treated as
        unlabelled.
    fix_markers
        If ``True`` marker probabilities do not get overwritten by the
        propagated labels.
    class_balance
        If ``True`` each class column of the propagated scores is rescaled so
        that its total mass equals the number of ground-truth members of that
        class, to keep large compartments from dominating the propagated
        labels. Classes that received no mass at all stay at zero.
    min_probability
        If the probability of the most probable label is below this threshold,
        the label is set to ``np.nan``. Only applied when ``inplace`` is
        ``True``.
    inplace
        If ``True`` (default) write the results into ``data`` and return
        ``None``. Otherwise leave ``data`` untouched and return the results as
        a dictionary.
    obsp_key
        Name of the neighbour connectivity graph to use (default
        ``"connectivities"``).
    key_added
        Name of the new column that will hold the propagated annotation
        (default ``"knn_annotation"``).

    Returns
    -------
    ``None`` when ``inplace`` is ``True``; ``data`` then has the new entries:

    - .obsm[f"{key_added}_probabilities"] containing the propagated probabilities
    - .obsm[f"{key_added}_one_hot_labels"] containing the one-hot encoded ground truth
    - .obs[f"{key_added}"] containing the propagated labels (most probable label)
    - .obs[f"{key_added}_probability"] containing the probability of the most probable label
    - .uns[f"{key_added}_colors"] copied from ``.uns[f"{gt_col}_colors"]`` (if present)
      so plotting uses the same colors as the ground truth labels

    When ``inplace`` is ``False`` a dictionary with the keys
    ``"probabilities"``, ``"labels"`` (the label categories) and
    ``"one_hot_labels"``.

    Raises
    ------
    KeyError
        If ``gt_col`` is not a column of ``data.obs`` or ``obsp_key`` is not in
        ``data.obsp``.
    """
    labels = data.obs[gt_col].astype("category")
    labels_one_hot = pd.get_dummies(labels).values
    T = data.obsp[obsp_key]

    # Propagate the labels with transition matrix T. Work in floating point:
    # with boolean/integer operands the uniform fallback below would be
    # truncated to zero.
    Y = np.asarray(T @ labels_one_hot.astype(float), dtype=float)
    # Observations without any labelled neighbour get a uniform distribution
    Y[Y.sum(axis=1) == 0] = 1 / Y.shape[1]

    if class_balance:
        # gt compartments with a lot of proteins are more likely to be in the
        # neighbourhood of a protein; rescale each class column accordingly.
        class_mass = np.nansum(Y, axis=0)
        class_sizes = labels_one_hot.sum(axis=0)
        # Skip classes without any mass (e.g. unused categories): 0/0 would
        # yield a NaN column that ``argmax`` then selects for every row.
        Y = (
            np.divide(Y, class_mass, out=np.zeros_like(Y), where=class_mass > 0)
            * class_sizes
        )

    # Normalize the propagated labels to get probabilities; rows without any
    # mass stay at zero instead of becoming NaN.
    row_mass = np.nansum(Y, axis=1)[:, None]
    Y = np.divide(Y, row_mass, out=np.zeros_like(Y), where=row_mass > 0)

    if fix_markers:
        # Markers (rows with exactly one ground-truth label) keep probability 1
        marker_mask = labels_one_hot.sum(axis=1) == 1
        Y[marker_mask] = labels_one_hot[marker_mask].astype(float)

    if not inplace:
        return {
            "probabilities": Y,
            "labels": labels.cat.categories,
            "one_hot_labels": labels_one_hot,
        }

    data.obsm[f"{key_added}_probabilities"] = Y
    data.obsm[f"{key_added}_one_hot_labels"] = labels_one_hot
    data.obs[f"{key_added}"] = labels.cat.categories[Y.argmax(axis=1)]
    data.obs[f"{key_added}_probability"] = np.max(Y, axis=1)
    # Drop low-confidence calls
    data.obs.loc[
        data.obs[f"{key_added}_probability"] < min_probability, f"{key_added}"
    ] = np.nan
    if f"{gt_col}_colors" in data.uns:
        data.uns[f"{key_added}_colors"] = data.uns[f"{gt_col}_colors"]
252-
253-
254-
def knn_annotation_old(
    data: AnnData,
    obs_ann_col: str,
    key_added: str = "consensus_graph_annotation",
    exclude_category: str | List[str] | None = None,
    inplace: bool = True,
) -> AnnData | None:
    """Propagate categorical annotations along the *k*-NN graph.

    For each observation the function inspects its neighbourhood in
    ``adata.obsp['distances']`` (generated by :func:`scanpy.pp.neighbors`) and
    assigns the majority category found in ``obs_ann_col``. Ties are broken
    arbitrarily using :func:`pandas.DataFrame.mode`.

    Parameters
    ----------
    data
        :class:`anndata.AnnData` with a populated neighbour graph in
        ``.obsp['distances']`` (expected to be a sparse matrix).
    obs_ann_col
        Observation column containing the *source* annotations to be
        propagated.
    key_added
        Name of the new column that will hold the *consensus* annotation
        (default ``"consensus_graph_annotation"``).
    exclude_category
        One or multiple category labels that should be ignored when computing
        the neighbourhood majority (useful for *unknown* / *NA* categories).
    inplace
        If ``True`` (default) modify *data* in place. Otherwise work on and
        return a copy with the additional column; *data* is not modified.

    Returns
    -------
    ``None`` when ``inplace`` is ``True``. Otherwise a copy of *data* with a
    new column in ``.obs[key_added]``.
    """
    # Work on a copy when not in place so the caller's object is never touched
    adata = data if inplace else data.copy()
    df = _get_knn_annotation_df(adata, obs_ann_col, exclude_category)

    conn = adata.obsp["distances"]
    # ``conn != 0`` stays sparse (``conn == 0`` would be expensive); only the
    # boolean neighbour mask is densified.
    is_neighbor = (conn != 0).toarray()
    # Blank out annotations of observations that are not neighbours
    df[~is_neighbor] = np.nan

    # Row-wise mode; column 0 takes the first value if there are ties
    majority_cluster = df.mode(axis=1, dropna=True).loc[:, 0]
    adata.obs[key_added] = majority_cluster.values
    return None if inplace else adata
301-
302-
303154
def to_knn_graph(
304155
data: AnnData,
305156
node_label_column: str | None = None,

0 commit comments

Comments
 (0)