Skip to content

Commit 49f5a7e

Browse files
committed
add qsep score
1 parent 131e4e6 commit 49f5a7e

File tree

4 files changed

+246
-2
lines changed

4 files changed

+246
-2
lines changed

grassp/plotting/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# from .heatmaps import grouped_heatmap
2-
from .heatmaps import protein_clustermap, sample_heatmap
2+
from .heatmaps import protein_clustermap, sample_heatmap, qsep_heatmap, qsep_boxplot
33
from .integration import aligned_umap, remodeling_sankey, remodeling_score
44
from .qc import bait_volcano_plots, highly_variable_proteins

grassp/plotting/heatmaps.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,167 @@ def sample_heatmap(
203203
# swap_axes=True,
204204
# show=False,
205205
# )
206+
207+
208+
def qsep_heatmap(
209+
data: AnnData,
210+
normalize: bool = True,
211+
ax: plt.Axes = None,
212+
cmap: str = "RdBu_r",
213+
**kwargs,
214+
) -> plt.Axes:
215+
"""Plot QSep cluster distance heatmap.
216+
217+
Parameters
218+
----------
219+
data : AnnData
220+
Annotated data matrix containing QSep results.
221+
normalize : bool, optional
222+
If True, normalize distances by diagonal values.
223+
Defaults to True.
224+
ax : matplotlib.axes.Axes, optional
225+
Axes to plot on. If None, current axes will be used.
226+
cmap : str, optional
227+
Colormap to use. Defaults to "RdBu_r".
228+
**kwargs
229+
Additional arguments passed to sns.heatmap.
230+
231+
Returns
232+
-------
233+
matplotlib.axes.Axes
234+
The axes object with the plot.
235+
"""
236+
if ax is None:
237+
ax = plt.gca()
238+
239+
try:
240+
distances = data.uns["cluster_distances"]["distances"]
241+
clusters = data.uns["cluster_distances"]["clusters"]
242+
except KeyError:
243+
raise ValueError(
244+
"Cluster distances not found in data.uns['cluster_distances'], run gr.tl.qsep_score first"
245+
)
246+
247+
if normalize:
248+
# Normalize by diagonal values
249+
norm_distances = distances / np.diag(distances)[:, np.newaxis]
250+
plot_data = norm_distances[::-1]
251+
vmin = 1.0
252+
vmax = np.max(norm_distances)
253+
else:
254+
plot_data = distances[::-1]
255+
vmin = None
256+
vmax = None
257+
258+
# Create heatmap
259+
sns.heatmap(
260+
plot_data,
261+
xticklabels=clusters,
262+
yticklabels=clusters[::-1], # Reverse the y-axis labels
263+
cmap=cmap,
264+
vmin=vmin,
265+
vmax=vmax,
266+
ax=ax,
267+
**kwargs,
268+
)
269+
270+
ax.set_title("QSep Cluster Distances" + (" (Normalized)" if normalize else ""))
271+
272+
return ax
273+
274+
275+
def qsep_boxplot(
276+
data: AnnData,
277+
normalize: bool = True,
278+
ax: plt.Axes = None,
279+
palette: str = "Set2",
280+
**kwargs,
281+
) -> plt.Axes:
282+
"""Plot QSep cluster distances as boxplots.
283+
284+
Parameters
285+
----------
286+
data : AnnData
287+
Annotated data matrix containing QSep results.
288+
normalize : bool, optional
289+
If True, normalize distances by diagonal values.
290+
Defaults to True.
291+
ax : matplotlib.axes.Axes, optional
292+
Axes to plot on. If None, current axes will be used.
293+
palette : str, optional
294+
Color palette for the boxplots. Defaults to "Set2".
295+
**kwargs
296+
Additional arguments passed to sns.boxplot.
297+
298+
Returns
299+
-------
300+
matplotlib.axes.Axes
301+
The axes object with the plot.
302+
"""
303+
if ax is None:
304+
ax = plt.gca()
305+
306+
try:
307+
distances = data.uns["cluster_distances"]["distances"]
308+
clusters = data.uns["cluster_distances"]["clusters"]
309+
except KeyError:
310+
raise ValueError(
311+
"Cluster distances not found in data.uns['cluster_distances'], run gr.tl.qsep_score first"
312+
)
313+
314+
if normalize:
315+
# Normalize by diagonal values
316+
distances = distances / np.diag(distances)[:, np.newaxis]
317+
318+
# Create DataFrame for plotting
319+
plot_data = []
320+
for i, ref_cluster in enumerate(clusters):
321+
for j, target_cluster in enumerate(clusters):
322+
plot_data.append(
323+
{
324+
"Reference Cluster": ref_cluster,
325+
"Target Cluster": target_cluster,
326+
"Distance": distances[i, j],
327+
"color": "grey" if i == j else "red",
328+
}
329+
)
330+
plot_df = pd.DataFrame(plot_data)
331+
print(plot_df)
332+
333+
# Create boxplot
334+
sns.boxplot(
335+
data=plot_df,
336+
x="Distance",
337+
y="Reference Cluster",
338+
color="grey",
339+
orient="h",
340+
ax=ax,
341+
legend=False,
342+
showfliers=False,
343+
**kwargs,
344+
)
345+
346+
# Add individual points
347+
sns.stripplot(
348+
data=plot_df,
349+
x="Distance",
350+
y="Reference Cluster",
351+
hue="color",
352+
# hue="Target Cluster",
353+
orient="h",
354+
size=4,
355+
# color=".3",
356+
alpha=0.6,
357+
ax=ax,
358+
legend=False,
359+
)
360+
361+
# Customize plot
362+
if normalize:
363+
ax.axvline(x=1 if normalize else 0, color="gray", linestyle="--", alpha=0.5)
364+
ax.set_xlabel("QSep Cluster Distances" + (" (Normalized)" if normalize else ""))
365+
366+
# Move legend outside
367+
# ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
368+
369+
return ax

grassp/tools/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,9 @@
77
)
88
from .enrichment import calculate_cluster_enrichment, rank_proteins_groups
99
from .integration import align_adatas, aligned_umap, remodeling_score
10-
from .scoring import calinski_habarasz_score, silhouette_score
10+
from .scoring import (
11+
calinski_habarasz_score,
12+
class_balance,
13+
qsep_score,
14+
silhouette_score,
15+
)

grassp/tools/scoring.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,78 @@ def calinski_habarasz_score(
8383
data.uns[key_added] = ch
8484
else:
8585
return ch
86+
87+
88+
def qsep_score(
89+
data: AnnData,
90+
label_key: str,
91+
use_rep: str = "X",
92+
distance_key: str = "full_distances",
93+
inplace: bool = True,
94+
) -> None | np.ndarray:
95+
"""Calculate QSep scores for spatial proteomics data.
96+
97+
Parameters
98+
----------
99+
data : AnnData
100+
Annotated data matrix.
101+
label_key : str
102+
Key in data.obs containing cluster labels.
103+
use_rep : str, optional
104+
Key for representation to use for distance calculation.
105+
Either 'X' or a key in data.obsm. Defaults to 'X'.
106+
distance_key : str, optional
107+
Key under which to store the full distances in data.obs.
108+
Defaults to 'full_distances'.
109+
inplace : bool, optional
110+
If True, store results in data, else return matrices.
111+
Defaults to True.
112+
113+
Returns
114+
-------
115+
None or np.ndarray
116+
If inplace=True, returns None and stores results in data.
117+
If inplace=False, returns cluster_distances.
118+
"""
119+
# Get data matrix
120+
if use_rep == "X":
121+
X = data.X
122+
else:
123+
X = data.obsm[use_rep]
124+
125+
# Calculate pairwise distances between all points
126+
full_distances = sklearn.metrics.pairwise_distances(X)
127+
128+
# Get valid clusters (non-NA)
129+
mask = data.obs[label_key].notna()
130+
valid_clusters = data.obs[label_key][mask].unique()
131+
132+
# Calculate cluster distances
133+
cluster_distances = np.zeros((len(valid_clusters), len(valid_clusters)))
134+
cluster_indices = {
135+
cluster: np.where(data.obs[label_key] == cluster)[0]
136+
for cluster in valid_clusters
137+
}
138+
139+
for i, cluster1 in enumerate(valid_clusters):
140+
for j, cluster2 in enumerate(valid_clusters):
141+
idx1 = cluster_indices[cluster1]
142+
idx2 = cluster_indices[cluster2]
143+
144+
# Get submatrix of distances between clusters
145+
submatrix = full_distances[np.ix_(idx1, idx2)]
146+
cluster_distances[i, j] = np.mean(submatrix)
147+
148+
if inplace:
149+
# Store full distances
150+
data.obs[distance_key] = pd.Series(
151+
np.mean(full_distances, axis=1), index=data.obs.index
152+
)
153+
154+
# Store cluster distances and metadata
155+
data.uns["cluster_distances"] = {
156+
"distances": cluster_distances,
157+
"clusters": valid_clusters.tolist(),
158+
}
159+
else:
160+
return cluster_distances

0 commit comments

Comments
 (0)