Skip to content

Commit 4480a76

Browse files
authored
Merge pull request #222 from scverse/add_bsize_psbulk
Add bsize psbulk
2 parents 5b29dc6 + dc695d6 commit 4480a76

3 files changed

Lines changed: 16 additions & 9 deletions

File tree

src/decoupler/mt/_gsea.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def _nesrank(
110110
neg_null_mean = null[neg_null_msk].mean()
111111
nes = -es / neg_null_mean
112112
else:
113-
nes = np.inf
114-
pval = np.inf
113+
nes = 0.0
114+
pval = 1.0
115115
return nes, pval
116116

117117

src/decoupler/pl/_filter_samples.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
@docs.dedent
1313
def filter_samples(
1414
adata: AnnData,
15-
groupby: str | list,
15+
groupby: str | list | None = None,
1616
log: bool = True,
1717
min_cells: int | float = 10,
1818
min_counts: int | float = 1000,
@@ -51,10 +51,11 @@ def filter_samples(
5151
assert all(col in adata.obs.columns for col in ["psbulk_cells", "psbulk_counts"]), (
5252
"psbulk_* columns not present in adata.obs, this function should be used after running decoupler.pp.pseudobulk"
5353
)
54-
assert isinstance(groupby, str | list), "groupby must be str or list"
54+
assert isinstance(groupby, str | list) or groupby is None, "groupby must be str, list or None"
5555
if isinstance(groupby, str):
5656
groupby = [groupby]
57-
assert all(col in adata.obs for col in groupby), "columns in groupby must be in adata.obs"
57+
if groupby:
58+
assert all(col in adata.obs for col in groupby), "columns in groupby must be in adata.obs"
5859
# Extract obs
5960
df = adata.obs.copy()
6061
# Transform to log10
@@ -65,7 +66,7 @@ def filter_samples(
6566
label_x, label_y = r"$\log_{10}$ " + label_x, r"$\log_{10}$ " + label_y
6667
min_cells, min_counts = np.log10(min_cells), np.log10(min_counts)
6768
# Plot
68-
if len(groupby) > 1:
69+
if groupby is not None and (len(groupby) > 1):
6970
# Instance
7071
assert kwargs.get("ax") is None, "when groupby is list, ax must be None"
7172
kwargs["ax"] = None
@@ -85,12 +86,14 @@ def filter_samples(
8586
ax.axhline(y=min_counts, linestyle="--", color="black")
8687
else:
8788
# Instance
88-
groupby = groupby[0]
89+
if isinstance(groupby, list):
90+
groupby = groupby[0]
8991
bp = Plotter(**kwargs)
9092
bp.ax.grid(zorder=0)
9193
bp.ax.set_axisbelow(True)
9294
sns.scatterplot(x="psbulk_cells", y="psbulk_counts", hue=groupby, ax=bp.ax, data=df, zorder=1)
93-
bp.ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False, title=groupby)
95+
if groupby:
96+
bp.ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False, title=groupby)
9497
bp.ax.set_xlabel(label_x)
9598
bp.ax.set_ylabel(label_y)
9699
bp.ax.axvline(x=min_cells, linestyle="--", color="black")

src/decoupler/pp/anndata.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def pseudobulk(
311311
empty: bool = False,
312312
mode: str | Callable | dict = "sum",
313313
skip_checks: bool = False,
314+
bsize: int = 250_000,
314315
verbose: bool = False,
315316
) -> AnnData:
316317
"""
@@ -365,7 +366,10 @@ def pseudobulk(
365366
assert isinstance(groups_col, str | list) or groups_col is None, "groups_col must be str or None"
366367
assert isinstance(mode, str | dict) or callable(mode), "mode must be str, dict or callable"
367368
# Extract data
368-
X, obs, var = extract(adata, layer=layer, raw=raw, empty=empty, verbose=verbose)
369+
X, obs, var = extract(adata, layer=layer, raw=raw, empty=empty, bsize=bsize, verbose=verbose)
370+
assert len(set(obs)) == len(obs), (
371+
"Repeated elements in adata.obs_names, to make them unique run adata.obs_names_make_unique()"
372+
)
369373
obs = adata.obs.loc[obs].copy()
370374
var = adata.var.loc[var]
371375
# Validate X

0 commit comments

Comments
 (0)