Skip to content

Commit 13f9ab2

Browse files
committed
Functional groups and exclude for Numba backend
1 parent b76e910 commit 13f9ab2

10 files changed

Lines changed: 293 additions & 98 deletions

File tree

illico/asymptotic_wilcoxon.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ def asymptotic_wilcoxon(
218218
tie_correct: bool = True,
219219
exp_post_agg: bool = False,
220220
layer: str | None = None,
221+
groups: list[str] | None = None,
222+
exclude_from_ovr: list[str] | None = None,
221223
precompile: bool = True,
222224
use_rust: bool = True,
223225
return_as_scanpy: bool = False,
@@ -264,6 +266,15 @@ def asymptotic_wilcoxon(
264266
Note that `scanpy.rank_genes_groups` assumes the data to be log1p, and exponentiates post aggregation by default.
265267
layer : str or None, default=None
266268
Layer in `adata.layers` to use for the data. If `None`, uses `adata.X`.
269+
groups : list of str or None, default=None
270+
Subset of groups to test. If `None`, tests all groups. This arguments serves the same purpose as scanpy's `groups` argument in `rank_genes_groups`.
271+
It is used to filter which groups to compare against the reference in the OVO scenario, or which groups to compare against the rest in the OVR scenario.
272+
Note that in the OVR scenario, each comparison still happens against the entirety of the other groups, not just the ones listed in this argument.
273+
Note that in the OVO scenario, the reference group is automatically added.
274+
exclude_from_ovr : list of str or None, default=None
275+
Subset of groups to exclude from the rest group in the OVR scenario. This argument is ignored in the OVO scenario.
276+
This can be useful if, for instance, one of the groups is corrupted and contains meaningless data, and we don't want it to be part of the comparisons in the OVR scenario.
277+
TODO: add warning about what values are okay or not okay taking interaction with `groups` into account.
267278
precompile : bool, default=True
268279
Whether to precompile necessary functions for performance. It is recommended to set this to `True`.
269280
use_rust : bool, default=True
@@ -376,21 +387,24 @@ def asymptotic_wilcoxon(
376387

377388
# Process the groups information
378389
unique_raw_groups, group_container = encode_and_count_groups(
379-
groups=adata.obs[group_keys].values, ref_group=reference
390+
groups=adata.obs[group_keys].values,
391+
ref_group=reference,
392+
group_subset=groups,
393+
exclude=exclude_from_ovr,
380394
)
381395
logger.info(
382-
f"Found {group_container.counts.size} unique groups (min size: {group_container.counts.min()} cells; "
396+
f"Found {group_container.counts.size} unique groups ({group_container.selected_group_ids.size} valid ones) (min size: {group_container.counts.min()} cells; "
383397
f"max size: {group_container.counts.max()} cells), with reference group: {reference}"
384398
)
385399
_, n_genes_total = X.shape
386400

387401
# Allocate the results dataframes
388402
cols = pd.Series(adata.var_names, name="feature", dtype=str)
389-
rows = pd.Series(unique_raw_groups, name="pert", dtype=str)
403+
rows = pd.Series(unique_raw_groups[group_container.selected_group_ids], name="pert", dtype=str)
390404
results = np.empty((len(rows), len(cols), 4), dtype=np.float64)
391405

392406
# Go through all the possible combinations
393-
n_tests = n_genes_total * group_container.counts.size
407+
n_tests = n_genes_total * group_container.selected_group_ids.size
394408
logger.trace(f"Performing a total of {n_tests:,d} tests.")
395409
with Parallel(n_threads, prefer="threads", return_as="generator_unordered") as pool:
396410
with tqdm(total=n_tests, smoothing=0.0, unit="it", unit_scale=True, unit_divisor=1000) as pbar:
@@ -427,7 +441,7 @@ def asymptotic_wilcoxon(
427441

428442
# Process chunks of columns one by one
429443
for lb, ub in pool(all_purpose_operator(data_handler, lb, ub, group_container, is_log1p, use_continuity, alternative, tie_correct, exp_post_agg, use_rust, results) for lb, ub in iterator): # fmt: skip
430-
pbar.update(group_container.counts.size * (ub - lb))
444+
pbar.update(group_container.selected_group_ids.size * (ub - lb))
431445

432446
if not return_as_scanpy:
433447
if n_genes is not None:
@@ -444,7 +458,7 @@ def asymptotic_wilcoxon(
444458
# Return a dict formatted for Scanpy's rank_genes_groups results
445459
results = format_illico_results_for_scanpy(
446460
adata=adata,
447-
unique_groups=unique_raw_groups,
461+
unique_groups=unique_raw_groups[group_container.selected_group_ids],
448462
reference=reference,
449463
group_keys=group_keys,
450464
layer=layer,

illico/ovo/dense_ovo.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from numba import njit
55

66
from illico.utils.groups import GroupContainer
7-
from illico.utils.math import chunk_and_fortranize, compute_pval, dense_fold_change
7+
from illico.utils.math import (
8+
chunk_and_fortranize,
9+
compute_pval,
10+
dense_fold_change,
11+
fancy_indexing_axis0,
12+
)
813
from illico.utils.ranking import (
914
_sort_along_axis_inplace,
1015
rank_sum_and_ties_from_sorted,
@@ -118,29 +123,32 @@ def dense_ovo_mwu_kernel_over_contiguous_col_chunk(
118123
ref_chunk = chunk_and_fortranize(X, chunk_lb, chunk_ub, ref_indices)
119124
_sort_along_axis_inplace(ref_chunk, axis=0)
120125

121-
pvalues = np.empty((n_groups, chunk_ub - chunk_lb), dtype=np.float64)
122-
zscores = np.empty((n_groups, chunk_ub - chunk_lb), dtype=np.float64)
123-
statistics = np.empty((n_groups, chunk_ub - chunk_lb), dtype=np.float64)
124-
for group_id in range(n_groups):
126+
n_selected_groups = grpc.selected_group_ids.size
127+
pvalues = np.empty((n_selected_groups, chunk_ub - chunk_lb), dtype=np.float64)
128+
zscores = np.empty((n_selected_groups, chunk_ub - chunk_lb), dtype=np.float64)
129+
statistics = np.empty((n_selected_groups, chunk_ub - chunk_lb), dtype=np.float64)
130+
for k, group_id in enumerate(grpc.selected_group_ids):
125131
if group_id == grpc.encoded_ref_group:
126-
pvalues[group_id, :] = 1.0
127-
zscores[group_id, :] = 0.0
128-
statistics[group_id, :] = -1.0
132+
pvalues[k, :] = 1.0
133+
zscores[k, :] = 0.0
134+
statistics[k, :] = -1.0
129135
continue
130136
tgt_indices = grpc.indices[grpc.indptr[group_id] : grpc.indptr[group_id + 1]]
131137
# tgt_chunk = np.asfortranarray(chunk[tgt_indices, :])
132138
tgt_chunk = chunk_and_fortranize(X, chunk_lb, chunk_ub, tgt_indices)
133139
_sort_along_axis_inplace(tgt_chunk, axis=0)
134140

135-
pvalues[group_id], statistics[group_id], zscores[group_id] = dense_ovo_mwu_kernel(
141+
pvalues[k], statistics[k], zscores[k] = dense_ovo_mwu_kernel(
136142
sorted_ref_data=ref_chunk,
137143
sorted_tgt_data=tgt_chunk,
138144
use_continuity=use_continuity,
139145
tie_correct=tie_correct,
140146
alternative=alternative,
141147
)
142148

143-
# Compute fold change
149+
# Compute fold change on all groups, but return it only for the selected groups
144150
fc = dense_fold_change(chunk, grpc, is_log1p=is_log1p, exp_post_agg=exp_post_agg)
151+
if n_selected_groups < n_groups:
152+
fc = fancy_indexing_axis0(fc, grpc.selected_group_ids)
145153

146154
return pvalues, statistics, zscores, fc

illico/ovo/sparse_ovo.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from numba import njit
55

66
from illico.utils.groups import GroupContainer
7-
from illico.utils.math import compute_pval, diff, fold_change_from_summed_expr
7+
from illico.utils.math import (
8+
compute_pval,
9+
diff,
10+
fancy_indexing_axis0,
11+
fold_change_from_summed_expr,
12+
)
813
from illico.utils.ranking import (
914
_sort_csc_columns_inplace,
1015
rank_sum_and_ties_from_sorted,
@@ -171,15 +176,16 @@ def csc_ovo_mwu_kernel_over_contiguous_col_chunk(
171176
agg_counts = np.empty((n_groups, chunk_ub - chunk_lb), dtype=np.float64)
172177

173178
# Now go through all the groups one by one
174-
pvalues = np.empty((n_groups, csc_X_ref.shape[1]), dtype=np.float64)
175-
zscores = np.empty((n_groups, csc_X_ref.shape[1]), dtype=np.float64)
176-
statistics = np.empty((n_groups, csc_X_ref.shape[1]), dtype=np.float64)
177-
for group_id in range(group_indptr.size - 1):
179+
n_selected_groups = grpc.selected_group_ids.size
180+
pvalues = np.empty((n_selected_groups, csc_X_ref.shape[1]), dtype=np.float64)
181+
zscores = np.empty((n_selected_groups, csc_X_ref.shape[1]), dtype=np.float64)
182+
statistics = np.empty((n_selected_groups, csc_X_ref.shape[1]), dtype=np.float64)
183+
for k, group_id in enumerate(grpc.selected_group_ids):
178184
if group_id == ref_group_id:
179-
pvalues[group_id, :] = 1.0
180-
zscores[group_id, :] = 0.0
181-
statistics[group_id, :] = -1.0
182-
agg_counts[ref_group_id, :] = csc_sum_axis0(csc_X_ref, expm1=is_log1p & (not exp_post_agg))
185+
pvalues[k, :] = 1.0
186+
zscores[k, :] = 0.0
187+
statistics[k, :] = -1.0
188+
agg_counts[k, :] = csc_sum_axis0(csc_X_ref, expm1=is_log1p & (not exp_post_agg))
183189
continue
184190

185191
# Chunk
@@ -197,11 +203,13 @@ def csc_ovo_mwu_kernel_over_contiguous_col_chunk(
197203
tie_correct=tie_correct,
198204
alternative=alternative,
199205
)
200-
pvalues[group_id, :] = pvalue
201-
statistics[group_id, :] = statistic
202-
zscores[group_id, :] = zscore
206+
pvalues[k, :] = pvalue
207+
statistics[k, :] = statistic
208+
zscores[k, :] = zscore
203209

204210
fold_change = fold_change_from_summed_expr(agg_counts, grpc, exp_post_agg=exp_post_agg & is_log1p)
211+
if n_selected_groups < n_groups:
212+
fold_change = fancy_indexing_axis0(fold_change, grpc.selected_group_ids)
205213

206214
return pvalues, statistics, zscores, fold_change
207215

@@ -264,19 +272,18 @@ def csr_ovo_mwu_kernel_over_contiguous_col_chunk(
264272
# Sort
265273
_sort_csc_columns_inplace(csc_matrix=csc_X_ref)
266274

267-
# Initalize aggregated matrix to compute fold change later on
268-
agg_counts = np.empty((n_groups, chunk_ub - chunk_lb), dtype=np.float64)
269-
270275
# Now go through all the groups one by one
271-
pvalues = np.empty((n_groups, csc_X_ref.shape[1]), dtype=np.float64)
272-
zscores = np.empty((n_groups, csc_X_ref.shape[1]), dtype=np.float64)
273-
statistics = np.empty((n_groups, csc_X_ref.shape[1]), dtype=np.float64)
274-
for group_id in range(group_indptr.size - 1):
276+
agg_counts = np.empty((n_groups, chunk_ub - chunk_lb), dtype=np.float64)
277+
n_selected_groups = grpc.selected_group_ids.size
278+
pvalues = np.empty((n_selected_groups, csc_X_ref.shape[1]), dtype=np.float64)
279+
zscores = np.empty((n_selected_groups, csc_X_ref.shape[1]), dtype=np.float64)
280+
statistics = np.empty((n_selected_groups, csc_X_ref.shape[1]), dtype=np.float64)
281+
for k, group_id in enumerate(grpc.selected_group_ids):
275282
if group_id == ref_group_id:
276-
pvalues[group_id, :] = 1.0
277-
zscores[group_id, :] = 0.0
278-
statistics[group_id, :] = -1.0
279-
agg_counts[ref_group_id, :] = csc_sum_axis0(csc_X_ref, expm1=is_log1p & (not exp_post_agg))
283+
pvalues[k, :] = 1.0
284+
zscores[k, :] = 0.0
285+
statistics[k, :] = -1.0
286+
agg_counts[k, :] = csc_sum_axis0(csc_X_ref, expm1=is_log1p & (not exp_post_agg))
280287
continue
281288

282289
# Chunk
@@ -294,10 +301,13 @@ def csr_ovo_mwu_kernel_over_contiguous_col_chunk(
294301
tie_correct=tie_correct,
295302
alternative=alternative,
296303
)
297-
pvalues[group_id, :] = pvalue
298-
statistics[group_id, :] = statistic
299-
zscores[group_id, :] = zscore
304+
pvalues[k, :] = pvalue
305+
statistics[k, :] = statistic
306+
zscores[k, :] = zscore
300307

308+
# Compute fold change for all groups, but return only the groups of interest
301309
fold_change = fold_change_from_summed_expr(agg_counts, grpc, exp_post_agg=exp_post_agg & is_log1p)
310+
if n_selected_groups < n_groups:
311+
fold_change = fancy_indexing_axis0(fold_change, grpc.selected_group_ids)
302312

303313
return pvalues, statistics, zscores, fold_change

illico/ovr/dense_ovr.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
from numba import njit
77

88
from illico.utils.groups import GroupContainer
9-
from illico.utils.math import chunk_and_fortranize, compute_pval, dense_fold_change
9+
from illico.utils.math import (
10+
_add_at_vec,
11+
chunk_and_fortranize,
12+
compute_pval,
13+
fancy_indexing_axis0,
14+
fold_change_from_summed_expr,
15+
)
1016
from illico.utils.ranking import _accumulate_group_ranksums_from_argsort
1117
from illico.utils.registry import KernelDataFormat, Test, nb_dispatcher_registry
1218

@@ -46,14 +52,17 @@ def dense_ovr_mwu_kernel_over_contiguous_col_chunk(
4652
4753
"""
4854
# Convert to F-order for faster column access and sorting later
49-
chunk = chunk_and_fortranize(X, chunk_lb, chunk_ub, None)
55+
chunk = chunk_and_fortranize(X, chunk_lb, chunk_ub, grpc.ovr_inclusion_indices)
5056

5157
# Get ranks and tie sums
5258
tie_sum = np.empty(chunk.shape[1], dtype=np.float64)
5359
ranksums = np.zeros(shape=(grpc.counts.size, chunk.shape[1]), dtype=np.float64)
60+
included_groups_indicator = grpc.encoded_groups[grpc.ovr_inclusion_indices]
5461
for j in range(chunk.shape[1]):
5562
idxs = np.argsort(chunk[:, j])
56-
col_tie_sum, _ = _accumulate_group_ranksums_from_argsort(chunk[:, j], idxs, grpc.encoded_groups, ranksums[:, j])
63+
col_tie_sum, _ = _accumulate_group_ranksums_from_argsort(
64+
chunk[:, j], idxs, included_groups_indicator, ranksums[:, j]
65+
)
5766
tie_sum[j] = col_tie_sum
5867

5968
# Compute U stats
@@ -63,22 +72,39 @@ def dense_ovr_mwu_kernel_over_contiguous_col_chunk(
6372
statistics = ranksums - n_tgt * (n_tgt + 1) / 2
6473
mu = n_ref * n_tgt / 2.0
6574
# Compute pvals
66-
pvals = np.empty(shape=(grpc.counts.size, chunk.shape[1]), dtype=np.float64)
67-
zscores = np.empty(shape=(grpc.counts.size, chunk.shape[1]), dtype=np.float64)
75+
n_selected_groups = grpc.selected_group_ids.size
76+
pvals = np.empty(shape=(n_selected_groups, chunk.shape[1]), dtype=np.float64)
77+
zscores = np.empty(shape=(n_selected_groups, chunk.shape[1]), dtype=np.float64)
6878
for j in range(chunk.shape[1]):
69-
for k in range(grpc.counts.size):
79+
for k, grp_id in enumerate(grpc.selected_group_ids):
7080
pvals[k, j], zscores[k, j] = compute_pval(
71-
n_ref=n_ref[k, 0],
72-
n_tgt=n_tgt[k, 0],
81+
n_ref=n_ref[grp_id, 0],
82+
n_tgt=n_tgt[grp_id, 0],
7383
n=n,
7484
tie_sum=tie_sum[j] if tie_correct else 0.0,
75-
U=statistics[k, j],
76-
mu=mu[k, 0],
85+
U=statistics[grp_id, j],
86+
mu=mu[grp_id, 0],
7787
contin_corr=0.5 if use_continuity else 0.0,
7888
alternative=alternative,
7989
)
8090

8191
# Get fold change
82-
fold_change = dense_fold_change(chunk, grpc=grpc, is_log1p=is_log1p, exp_post_agg=exp_post_agg)
92+
# Note: it would be a bit cumbersome to have dense_fold_change handle itself all the shennanigans
93+
# groups and subsetting. I find clearer to have it here.
94+
# TODO: actually idk, bc I ended up doing it in the sparse path.
95+
group_agg_counts = np.zeros(shape=(grpc.counts.size, X.shape[1]), dtype=np.float64)
96+
# Sum expressions per group
97+
if is_log1p and not exp_post_agg:
98+
_add_at_vec(group_agg_counts, grpc.encoded_groups[grpc.ovr_inclusion_indices], np.expm1(chunk))
99+
else:
100+
_add_at_vec(group_agg_counts, grpc.encoded_groups[grpc.ovr_inclusion_indices], chunk)
101+
fold_change = fold_change_from_summed_expr(
102+
group_agg_counts, grpc, exp_post_agg=exp_post_agg & is_log1p, sum_over_selected_groups_only=True
103+
)
104+
105+
# Now filter on the groups to return, if needed
106+
if n_selected_groups < grpc.counts.size:
107+
fold_change = fancy_indexing_axis0(fold_change, grpc.selected_group_ids)
108+
statistics = fancy_indexing_axis0(statistics, grpc.selected_group_ids)
83109

84110
return pvals, statistics, zscores, fold_change

0 commit comments

Comments
 (0)