
Commit 59fdec4

Merge pull request #1 from czbiohub-sf/dev

Markov clustering, ternary plots, update license metadata

2 parents: 3f5e566 + 4210723

File tree

8 files changed (+151, -13 lines)


grassp/plotting/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 # from .heatmaps import grouped_heatmap
 from .heatmaps import protein_clustermap, sample_heatmap, qsep_heatmap, qsep_boxplot
-from .integration import aligned_umap, remodeling_sankey, remodeling_score
+from .integration import aligned_umap, remodeling_sankey, remodeling_score, mr_plot
 from .qc import bait_volcano_plots, highly_variable_proteins
 from .clustering import tagm_map_contours, tagm_map_pca_ellipses
+from .ternary import ternary
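
Both new plotting entry points are re-exported at the package level, so after this change they can be imported directly from grassp.plotting. A minimal import sketch (mr_plot's signature is not part of this diff, so it is only imported here):

from grassp.plotting import ternary, mr_plot  # new exports in this commit
from grassp.plotting import qsep_heatmap      # existing export, gains vmin/vmax below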

grassp/plotting/heatmaps.py

Lines changed: 11 additions & 4 deletions
@@ -210,6 +210,8 @@ def qsep_heatmap(
     normalize: bool = True,
     ax: plt.Axes = None,
     cmap: str = "RdBu_r",
+    vmin: float = None,
+    vmax: float = None,
     **kwargs,
 ) -> plt.Axes:
     """Plot QSep cluster distance heatmap.
@@ -248,12 +250,17 @@ def qsep_heatmap(
         # Normalize by diagonal values
         norm_distances = distances / np.diag(distances)[:, np.newaxis]
         plot_data = norm_distances[::-1]
-        vmin = 1.0
-        vmax = np.max(norm_distances)
+        tvmin = 1.0
+        tvmax = np.max(norm_distances)
     else:
         plot_data = distances[::-1]
-        vmin = None
-        vmax = None
+        tvmin = None
+        tvmax = None
+
+    if vmin is None:
+        vmin = tvmin
+    if vmax is None:
+        vmax = tvmax
 
     # Create heatmap
     sns.heatmap(
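
The new vmin/vmax parameters let callers pin the color limits; when left at None they fall back to the previously hard-coded values (1.0 and the normalized maximum when normalize=True, matplotlib defaults otherwise). A hedged usage sketch, assuming qsep_heatmap takes the AnnData object holding the QSep distances as its first argument (the leading parameters are outside this hunk):

import matplotlib.pyplot as plt
from grassp.plotting import qsep_heatmap

fig, ax = plt.subplots(figsize=(6, 5))
# adata is a placeholder for an AnnData on which the QSep distances were already computed.
# Fixing vmin/vmax puts two datasets on the same color scale.
qsep_heatmap(adata, normalize=True, ax=ax, cmap="RdBu_r", vmin=1.0, vmax=3.0)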

grassp/plotting/ternary.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Optional, List
+    from anndata import AnnData
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import numpy as np
+
+from scanpy.plotting._tools.scatterplots import (
+    _color_vector,
+    _get_color_source_vector,
+    _add_categorical_legend,
+    _get_palette,
+)
+
+
+def ternary(
+    adata: AnnData,
+    color: Optional[str] = None,
+    ax=None,
+    labels: Optional[List[str]] = None,
+    show: bool = True,
+    colorbar_loc: Optional[str] = None,
+    legend_loc: Optional[str] = None,
+    legend_fontweight: Optional[str] = None,
+    legend_fontsize: Optional[int] = None,
+    legend_fontoutline: Optional[str] = None,
+    na_in_legend: Optional[bool] = None,
+    **kwargs,
+):
+    try:
+        import mpltern
+    except ImportError:
+        raise ImportError(
+            "mpltern is not installed. Please install it with `pip install mpltern`"
+        )
+    if adata.X.shape[1] != 3:
+        raise ValueError("Ternary plots require an AnnData object with 3 samples (columns)")
+    if ax is None:
+        ax = plt.subplot(projection="ternary")
+    if labels is None:
+        labels = adata.var_names
+
+    csv = _get_color_source_vector(adata, color)
+
+    cv, color_type = _color_vector(adata, values_key=color, values=csv, palette=None)
+
+    # Make sure that nan values are plotted below the other points
+    nan_mask = np.isnan(csv) if isinstance(csv, np.ndarray) else csv.isna()
+    if nan_mask.any():
+        nan_points = adata[nan_mask].X
+        ax.scatter(
+            nan_points[:, 0],
+            nan_points[:, 1],
+            nan_points[:, 2],
+            c=cv[nan_mask],
+            **kwargs,
+            zorder=0,
+        )
+    cax = ax.scatter(
+        adata.X[~nan_mask, 0],
+        adata.X[~nan_mask, 1],
+        adata.X[~nan_mask, 2],
+        zorder=1,
+        c=cv[~nan_mask],
+        **kwargs,
+    )
+    ax.taxis.set_label_position("tick1")
+    ax.raxis.set_label_position("tick1")
+    ax.laxis.set_label_position("tick1")
+    ax.set_tlabel(labels[0])
+    ax.set_llabel(labels[1])
+    ax.set_rlabel(labels[2])
+
+    if color_type == "cat":
+        _add_categorical_legend(
+            ax,
+            csv,
+            palette=_get_palette(adata, color),
+            scatter_array=None,
+            legend_loc=legend_loc,
+            legend_fontweight=legend_fontweight,
+            legend_fontsize=legend_fontsize,
+            legend_fontoutline=legend_fontoutline,
+            na_color="grey",
+            na_in_legend=na_in_legend,
+            multi_panel=False,
+        )
+    elif colorbar_loc is not None:
+        plt.colorbar(
+            cax, ax=ax, pad=0.01, fraction=0.08, aspect=30, location=colorbar_loc
+        )
+    if show:
+        plt.show()
+    return ax
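
ternary() projects a three-column AnnData onto an mpltern ternary axis and reuses scanpy's internal color handling, so color accepts the same .obs/.var keys as scanpy scatter plots. A self-contained sketch with synthetic data (fraction names and categories are invented for illustration; exact rendering depends on the installed mpltern and scanpy versions):

import numpy as np
import pandas as pd
import anndata as ad
from grassp.plotting import ternary

# 200 proteins measured across exactly three fractions; ternary() raises otherwise.
rng = np.random.default_rng(0)
adata = ad.AnnData(rng.dirichlet([2, 2, 2], size=200))
adata.var_names = ["Cyt", "Mem", "Nuc"]  # hypothetical fraction labels
adata.obs["compartment"] = pd.Categorical(
    rng.choice(["ER", "Golgi", "cytosol"], size=200)  # hypothetical annotation
)

ax = ternary(adata, color="compartment", show=False, s=10)
ax.figure.savefig("ternary_example.png", dpi=150)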

grassp/preprocessing/enrichment.py

Lines changed: 29 additions & 4 deletions
@@ -11,6 +11,19 @@
 import warnings
 
 
+def _check_covariates(data: AnnData, covariates: Optional[list[str]] = None):
+    if covariates is None:
+        covariates = data.var.columns[data.var.columns.str.startswith("covariate_")]
+    # Check that all covariates are in the data
+    for c in covariates:
+        if c not in data.var.columns:
+            raise ValueError(f"Covariate {c} not found in data.var.columns")
+
+    if not isinstance(covariates, list):
+        covariates = [covariates]
+    return covariates
+
+
 def calculate_enrichment_vs_untagged(
     data: AnnData,
     covariates: Optional[list[str]] = [],
@@ -120,6 +133,17 @@ def calculate_enrichment_vs_untagged(
     return data_aggr
 
 
+def calculate_noc_proportions(
+    adata: AnnData,
+    covariates: Optional[list[str]] = None,
+    subcellular_enrichment_column: str = "subcellular_enrichment",
+    use_layer: Optional[str] = None,
+    original_intensities_key: str | None = None,
+    keep_raw: bool = True,
+) -> AnnData:
+    pass
+
+
 def calculate_enrichment_vs_all(
     adata: AnnData,
     covariates: Optional[list[str]] = None,
@@ -153,9 +177,10 @@ def calculate_enrichment_vs_all(
 
     data = adata.copy()
 
-    # if covariates is None:
-    # covariates = data.var.columns[data.var.columns.str.startswith("covariate_")]
-    # else:
+    if covariates is None:
+        covariates = data.var.columns[
+            data.var.columns.str.startswith("covariate_")
+        ].tolist()
     # Check that all covariates are in the data
     for c in covariates:
         if c not in data.var.columns:
@@ -198,7 +223,7 @@ def calculate_enrichment_vs_all(
     lfc = np.median(intensities_ip, axis=1) - np.median(intensities_control, axis=1)
     aggr_mask = data_aggr.var["_experimental_condition"] == experimental_condition
     data_aggr.layers["pvals"][:, aggr_mask] = pv[:, None]
-    data_aggr[:, aggr_mask].X = lfc[:, None]
+    data_aggr.X[:, aggr_mask] = lfc[:, None]
     data_aggr.var.loc[aggr_mask, "enriched_vs"] = ",".join(
         data_aggr.var_names[control_mask]
     )
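
Both the new _check_covariates helper and the un-commented block in calculate_enrichment_vs_all encode the same convention: with covariates=None, every data.var column whose name starts with "covariate_" is used. A short sketch of that convention (column names and values are invented; adata stands for an AnnData already prepared for grassp's enrichment workflow):

from grassp.preprocessing.enrichment import calculate_enrichment_vs_all

# Sample-level annotation lives in .var because samples are columns in grassp.
adata.var["covariate_batch"] = ["b1", "b1", "b2", "b2"]

# Under the new default handling these two calls should be equivalent:
res_auto = calculate_enrichment_vs_all(adata)  # picks up covariate_batch automatically
res_expl = calculate_enrichment_vs_all(adata, covariates=["covariate_batch"])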

grassp/preprocessing/simple.py

Lines changed: 1 addition & 1 deletion
@@ -532,7 +532,7 @@ def normalize_total(
 
 def drop_excess_MQ_metadata(
     data: AnnData,
-    colname_regex: str = "Peptide|peptide|MS/MS|Evidence IDs|Taxonomy|Oxidation|Intensity|Identification type|Sequence coverage|MS/MS count",
+    colname_regex: str = "Peptide|peptide|MS/MS|Evidence IDs|Taxonomy|Oxidation|Intensity|Total Spectral Count|Unique Spectral Count|Spectral Count|Identification type|Sequence coverage|MS/MS count",
     inplace: bool = True,
 ) -> AnnData | None:
     """Drop excess metadata columns from MaxQuant output.

grassp/tools/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
     tagm_map_predict,
 )
 from .enrichment import calculate_cluster_enrichment, rank_proteins_groups
-from .integration import align_adatas, aligned_umap, remodeling_score
+from .integration import align_adatas, aligned_umap, remodeling_score, mr_score
 from .scoring import (
     calinski_habarasz_score,
     class_balance,
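
mr_score is now exposed from grassp.tools, mirroring the mr_plot export added to grassp.plotting above; its signature is not part of this diff, so only the imports are shown:

from grassp.tools import mr_score      # new export
from grassp.plotting import mr_plot    # companion plotting function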

grassp/tools/scoring.py

Lines changed: 4 additions & 1 deletion
@@ -205,13 +205,16 @@ def qsep_score(
     }
 
     for i, cluster1 in enumerate(valid_clusters):
-        for j, cluster2 in enumerate(valid_clusters):
+        for j in range(i, len(valid_clusters)):
+            # for j, cluster2 in enumerate(valid_clusters[i + 1 :]):
+            cluster2 = valid_clusters[j]
             idx1 = cluster_indices[cluster1]
             idx2 = cluster_indices[cluster2]
 
             # Get submatrix of distances between clusters
             submatrix = full_distances[np.ix_(idx1, idx2)]
             cluster_distances[i, j] = np.mean(submatrix)
+            cluster_distances[j, i] = np.mean(submatrix)
 
     if inplace:
         # Store full distances
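
The rewritten loop computes each between-cluster mean distance once for j >= i and mirrors it into the lower triangle, which keeps the matrix symmetric while roughly halving the number of submatrix means. A standalone sketch of the same pattern (generic names, not grassp's internals):

import numpy as np

def mean_cluster_distances(full_distances, cluster_indices, clusters):
    """Symmetric matrix of mean pairwise distances between clusters."""
    n = len(clusters)
    out = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):  # upper triangle only
            idx1 = cluster_indices[clusters[i]]
            idx2 = cluster_indices[clusters[j]]
            d = np.mean(full_distances[np.ix_(idx1, idx2)])
            out[i, j] = d
            out[j, i] = d  # mirror
    return out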

pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -9,7 +9,7 @@ requires = [
 name = "grassp"
 description = "A python package to facilitate Organellar profiling"
 readme = "README.md"
-license = {file = "LICENSE"}
+license = {text = "BSD 3-Clause License"}
 requires-python = ">=3.7, <4.0"
 
 # the dynamically determined project metadata attributes
@@ -38,6 +38,7 @@ dependencies = [
     "umap-learn",
     "pysankeybeta",
     "gseapy",
+    "markov_clustering",
 ]
 
 
@@ -84,6 +85,9 @@ packages = ["grassp"]
 # allow use of __file__ to load data files included in the package
 zip-safe = false
 
+# Don't include LICENSE as a license-file in the metadata
+license-files = []
+
 
 [tool.black]
 line-length = 95
