Skip to content

Commit 500cd6c

Browse files
committed
Fixed batching and added rapids-singlecell benchmarks
1 parent 791756b commit 500cd6c

5 files changed

Lines changed: 67 additions & 28 deletions

File tree

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# illico
22
`illico` is a Python library that performs fast and lightweight Wilcoxon rank-sum tests (same as `scanpy.tl.rank_genes_groups(…, method="wilcoxon")`), useful for single-cell RNA-seq data analysis and processing.
3-
Approximate speed benchmarks (done on a 8-CPUs machine) ran on k562-essential can be found below.
3+
Approximate speed benchmarks (run on an 8-CPU, 1-GPU machine) on k562-essential (~300k cells, 8k genes, 2k perturbations) can be found below.
44

5-
| Test | Format | illico | scanpy | pdex |
6-
|----------------------------------|--------|--------|--------|------|
7-
| OVO (reference="non-targeting") | Dense | ~20s | ~1h | ~20min |
8-
| OVO (reference="non-targeting") | Sparse | ~15s | ~1h30 | ~8min |
9-
| OVR (reference=None) | Dense | ~10s | >10h | >10h |
10-
| OVR (reference=None) | Sparse | ~10s | >10h | >10h |
5+
| Test | Format | illico | scanpy | pdex | rapids-singlecell (GPU) |
6+
|----------------------------------|--------|--------|--------|------|------------------ |
7+
| OVO (reference="non-targeting") | Dense | ~20s | ~1h | ~20min | ~25min |
8+
| OVO (reference="non-targeting") | Sparse | ~15s | ~1h30min | ~8min | ~1h10min |
9+
| OVR (reference=None) | Dense | ~10s | >10h | >10h | ~1min |
10+
| OVR (reference=None) | Sparse | ~10s | >10h | >10h | ~1min |
1111

1212
## Installation
1313
illico is compatible with python 3.11 and onward:
@@ -30,7 +30,7 @@ de_genes = asymptotic_wilcoxon(
3030
group_keys="perturbation",
3131
reference=["non-targeting"|None], # <- `None` computes cluster-wise DE genes. Any other `str` will be interpreted as label of the control cells.
3232
is_log1p=[False|True], # <-- Specify if your data underwent log1p or not
33-
return_as_scanpy=[False|True], # <-- Whether to return a dict compatible with Scanpy's `rank_genes_groups` function, or a pd.DataFrame
33+
return_as_scanpy=[False|True], # <-- Whether to return a dict compatible with Scanpy's `rank_genes_groups` function, or a pd.DataFrame holding all p-values, statistics, and fold-change
3434
)
3535
# Eventually, if return_as_scanpy=True:
3636
adata.uns["rank_genes_groups"] = de_genes

changelog.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
Changelog
22
=========
33

4+
Version 0.4.0
5+
------------
6+
- Added option to return scanpy-friendly output with `return_as_scanpy` arg. `asymptotic_wilcoxon` returns either:
7+
- A `pandas.DataFrame` with columns `feature`, `p_value`, `fold_change`, and `statistic` (default), if `return_as_scanpy=False`
8+
- A dictionary containing the same keys as `scanpy.tl.rank_genes_groups`, if `return_as_scanpy=True`. As in scanpy, genes are ordered by decreasing z-score.
9+
- Improved the batching mechanism, fixed the 'auto' mode that was excluding the very last gene in previous versions.
10+
411
Version 0.3.0
512
------------
613
- Rust backend is available for all tests. Compare Rust vs Numba with `poetry run pytest-benchmark compare 0003 0005`:

illico/asymptotic_wilcoxon.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
from typing import Literal
32

43
import anndata as ad
@@ -12,6 +11,7 @@
1211

1312
from illico.utils.compile import _precompile
1413
from illico.utils.groups import GroupContainer, encode_and_count_groups
14+
from illico.utils.math import compute_batch_bounds
1515
from illico.utils.memory import log_memory_usage
1616
from illico.utils.ranking import check_indices_sorted_per_parcel
1717
from illico.utils.registry import (
@@ -254,24 +254,9 @@ def asymptotic_wilcoxon(
254254
rows = pd.Series(unique_raw_groups, name="pert", dtype=str)
255255
results = np.empty((len(rows), len(cols), 4), dtype=np.float64)
256256

257-
# Adapt batch size to leverage multithreading regarding the number of genes, if requested
258-
if n_genes < 256:
259-
batch_size = n_genes # No batching for small number of genes
260-
n_threads = 1 # No multithreading for small number of genes
261-
iterator = [[0, n_genes]]
262-
elif isinstance(batch_size, int):
263-
batch_size = min(batch_size, math.ceil(n_genes / n_threads))
264-
bounds = np.append(np.arange(0, n_genes, batch_size), n_genes)
265-
iterator = list(zip(bounds[:-1], bounds[1:]))
266-
elif batch_size == "auto":
267-
n_dispatches = max(int(n_genes / 256 / n_threads), 1) # Aim for approximately 256 genes per chunk
268-
splits = np.array_split(np.arange(n_genes + 1), indices_or_sections=n_threads * n_dispatches)
269-
iterator = [[split[0], split[-1] + 1] for split in splits]
270-
iterator[-1][-1] = n_genes # Ensure the last upper bound is exactly n_genes
271-
batch_size = int(np.ceil(n_genes / (n_dispatches * n_threads)))
272-
else:
273-
raise ValueError(f"Invalid batch_size value: {batch_size}. Must be 'auto' or an integer.")
274-
logger.trace(f"Using batch size of {batch_size} for {n_threads} threads and {n_genes} genes.")
257+
# Compute the batch bounds for each thread
258+
iterator, batch_size = compute_batch_bounds(n_genes, batch_size, n_threads)
259+
logger.trace(f"Processing {n_genes} genes through {len(iterator)} batches with {n_threads} threads.")
275260

276261
# Compute estimated mem footprint
277262
_ = log_memory_usage(data_handler, group_container, batch_size, n_threads)

illico/utils/math.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import math
44
import warnings
5-
from typing import Literal
5+
from typing import List, Literal, Tuple
66

77
import numpy as np
88
from numba import njit
@@ -281,3 +281,49 @@ def chunk_and_fortranize(X: np.ndarray, chunk_lb: int, chunk_ub: int, indices: n
281281
for j in range(0, chunk_ub - chunk_lb):
282282
chunk[i, j] = X[i, chunk_lb + j]
283283
return chunk
284+
285+
286+
def compute_batch_bounds(n_genes: int, batch_size: Literal["auto"] | int, n_threads: int) -> List[Tuple[int, int]]:
287+
"""Computes ideal batch bounds for processing genes in batches.
288+
This function ensures no worker is starving. This could happen if we have 8 workers but 9 batches to allocate.
289+
In this case, because each batch takes the same time to be processed, all but one workers will be idle waiting for one worker to process the last batch.
290+
291+
Args:
292+
n_genes (int): Total number of genes
293+
batch_size (Literal["auto"] | int): Batch size, or "auto" to compute ideal batch size.
294+
n_threads (int): Number of threads to use.
295+
Returns:
296+
List[Tuple[int, int]]: List of (lower_bound, upper_bound) for each batch. Upper bound is excluding, following slicing conventions.
297+
"""
298+
# No batching nor multithreading for small inputs
299+
if n_genes < n_threads or n_genes < 256:
300+
batch_size = n_genes
301+
# n_threads = 1
302+
batch_size = n_genes
303+
bounds_iterator = [[0, n_genes]]
304+
elif isinstance(batch_size, int):
305+
# batch_size = min(batch_size, math.ceil(n_genes / n_threads))
306+
bounds = list(range(0, n_genes + 1, batch_size))
307+
if bounds[-1] != n_genes:
308+
bounds.append(n_genes)
309+
bounds_iterator = list(zip(bounds[:-1], bounds[1:]))
310+
elif batch_size == "auto":
311+
target_batch_size = 256
312+
min_batches = (n_genes + target_batch_size - 1) // target_batch_size
313+
num_batches = ((min_batches + n_threads - 1) // n_threads) * n_threads
314+
base_size = n_genes // num_batches
315+
remainder = n_genes % num_batches
316+
bounds_iterator = []
317+
start = 0
318+
for i in range(num_batches):
319+
end = start + base_size + (1 if i < remainder else 0)
320+
bounds_iterator.append((start, end))
321+
start = end
322+
# Append the last gene as the right bound is excluding
323+
if bounds_iterator[-1][1] != n_genes:
324+
bounds_iterator[-1][1] = n_genes
325+
batch_size = base_size
326+
else:
327+
raise ValueError(f"Invalid batch_size value: {batch_size}. Must be 'auto' or an integer.")
328+
329+
return bounds_iterator, batch_size

tests/test_asymptotic_wilcoxon.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def test_asymptotic_wilcoxon_auto_batchsize(eager_rand_adata):
494494
bigger_eager_rand_adata = ad.concat(
495495
[eager_rand_adata] * int(math.ceil(target_n_cols / eager_rand_adata.n_vars)), axis=1
496496
)
497+
bigger_eager_rand_adata.var_names_make_unique()
497498
bigger_eager_rand_adata.obs = eager_rand_adata.obs.copy()
498499
asy_results = asymptotic_wilcoxon(
499500
adata=bigger_eager_rand_adata,

0 commit comments

Comments
 (0)