Skip to content

Commit c4adf47

Browse files
authored
[ENH] make ALE/FWE faster (#999)
* make ALE faster: jit the approximate-null histogram update; change monte carlo/FWE permutation to pass around precomputed ijk voxel indices instead of doing xyz->ijk each iteration * change the null-approximate calculation: kept the simple per-study loop and removed the slower “all-in-one” compiled histogram path; replaced int(np.floor(...)) with direct truncation for nonnegative binning in _study_ma_histogram and _update_ale_histogram; normalized study histograms with a precomputed reciprocal instead of exp_hist.sum(); stopped forcing extra per-study float64 and astype(...) copies before histogram merging * style fix and test * fix style * fix style * fix outdated comments and add edge case tests * fix style
1 parent bdc6a0b commit c4adf47

5 files changed

Lines changed: 215 additions & 50 deletions

File tree

nimare/meta/cbma/ale.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
import pandas as pd
1111
from joblib import Memory, Parallel, delayed
12+
from numba import jit
1213
from scipy import ndimage
1314
from scipy import sparse as sp_sparse
1415
from tqdm.auto import tqdm
@@ -66,6 +67,47 @@ def _compute_ale_summarystat(ma_values):
6667
raise ValueError(f"Unsupported data type '{type(ma_values)}'")
6768

6869

70+
@jit(nopython=True, cache=True)
71+
def _study_ma_histogram(study_ma_values, n_zero_voxels, mask_voxel_recip, inv_step_size, n_bins):
    """Build the scaled histogram of one study's values on a fixed bin grid.

    Parameters
    ----------
    study_ma_values : 1D array
        Values to bin. Each value is mapped to the bin index
        ``int(value * inv_step_size)``, clamped to ``[0, n_bins - 1]``.
    n_zero_voxels : number
        Count added to the first bin after the values have been binned.
    mask_voxel_recip : float
        Factor every bin is multiplied by before the histogram is returned.
    inv_step_size : float
        Multiplier that converts a value into a (fractional) bin position.
    n_bins : int
        Length of the returned histogram.

    Returns
    -------
    1D float64 array of length ``n_bins``
        The scaled bin counts.
    """
    last_bin = n_bins - 1
    counts = np.zeros(n_bins, dtype=np.float64)

    for ma_value in study_ma_values:
        # int() truncates toward zero; out-of-range positions are clamped to the end bins.
        bin_idx = min(max(int(ma_value * inv_step_size), 0), last_bin)
        counts[bin_idx] += 1.0

    counts[0] += n_zero_voxels
    counts *= mask_voxel_recip
    return counts
85+
86+
87+
@jit(nopython=True, cache=True)
88+
def _update_ale_histogram(
    ale_idx, ale_probs, exp_idx, exp_probs, bin_centers, inv_step_size, n_bins, out
):
    """Merge two sparse histograms into ``out`` and return it.

    Every pairing of one ``exp`` bin with one ``ale`` bin contributes the
    product of their probabilities to the bin of the combined score
    ``1 - (1 - exp_center) * (1 - ale_center)``.

    Parameters
    ----------
    ale_idx, exp_idx : 1D integer arrays
        Positions into ``bin_centers`` for each histogram's entries.
    ale_probs, exp_probs : 1D float arrays
        Probabilities matching ``ale_idx`` and ``exp_idx`` element by element.
    bin_centers : 1D float array
        Value associated with each bin.
    inv_step_size : float
        Multiplier that converts a score into a (fractional) bin position.
    n_bins : int
        Number of bins that are reset and may be written in ``out``.
    out : 1D float array
        Caller-provided buffer; its first ``n_bins`` entries are overwritten.

    Returns
    -------
    1D float array
        ``out``, after the update.
    """
    # Clear whatever a previous call left in the reusable buffer.
    out[:n_bins] = 0.0

    top_bin = n_bins - 1
    n_ale = ale_idx.shape[0]
    for exp_pos in range(exp_idx.shape[0]):
        exp_weight = exp_probs[exp_pos]
        exp_complement = 1.0 - bin_centers[exp_idx[exp_pos]]
        for ale_pos in range(n_ale):
            combined = 1.0 - exp_complement * (1.0 - bin_centers[ale_idx[ale_pos]])
            # Truncate to a bin position and clamp it onto the grid.
            target = min(max(int(combined * inv_step_size), 0), top_bin)
            out[target] += exp_weight * ale_probs[ale_pos]

    return out
109+
110+
69111
def _collect_masked_ma_maps(estimator, coords_key="coordinates", maps_key="ma_maps"):
70112
"""Collect ALE-family MA maps in masked CSR form."""
71113
estimator._study_max_ma_values = None
@@ -345,18 +387,18 @@ def _compute_null_approximate(self, ma_maps):
345387

346388
assert "histogram_bins" in self.null_distributions_.keys()
347389

348-
# Derive bin edges from histogram bin centers for numpy histogram function
349-
bin_centers = self.null_distributions_["histogram_bins"]
390+
# Reuse the fixed histogram grid derived earlier in _determine_histogram_bins.
391+
bin_centers = self.null_distributions_["histogram_bins"].astype(np.float64, copy=False)
350392
step_size = bin_centers[1] - bin_centers[0]
351393
inv_step_size = 1 / step_size
352-
bin_edges = bin_centers - (step_size / 2)
353-
bin_edges = np.append(bin_centers, bin_centers[-1] + step_size)
354-
394+
n_bins = bin_centers.shape[0]
395+
mask_voxel_recip = 1.0 / self.__n_mask_voxels
355396
n_exp = ma_maps.shape[0]
356397
data = ma_maps.data
357398
indptr = ma_maps.indptr
358399

359400
ale_hist = None
401+
tmp_hist = np.zeros(n_bins, dtype=np.float64)
360402
for exp_idx in range(n_exp):
361403
start = indptr[exp_idx]
362404
end = indptr[exp_idx + 1]
@@ -365,32 +407,31 @@ def _compute_null_approximate(self, ma_maps):
365407
n_nonzero_voxels = study_ma_values.shape[0]
366408
n_zero_voxels = self.__n_mask_voxels - n_nonzero_voxels
367409

368-
exp_hist = np.histogram(study_ma_values, bins=bin_edges, density=False)[0].astype(
369-
float
410+
exp_hist = _study_ma_histogram(
411+
study_ma_values,
412+
n_zero_voxels,
413+
mask_voxel_recip,
414+
inv_step_size,
415+
n_bins,
370416
)
371-
exp_hist[0] += n_zero_voxels
372-
exp_hist /= exp_hist.sum()
373417

374418
if ale_hist is None:
375419
ale_hist = exp_hist.copy()
376420
continue
377421

378-
# Find histogram bins with nonzero values for each histogram.
379422
ale_idx = np.where(ale_hist > 0)[0]
380-
exp_idx = np.where(exp_hist > 0)[0]
381-
382-
# Compute output MA values, ale_hist indices, and probabilities
383-
ale_scores = (
384-
1 - np.outer((1 - bin_centers[exp_idx]), (1 - bin_centers[ale_idx])).ravel()
423+
exp_hist_idx = np.where(exp_hist > 0)[0]
424+
_update_ale_histogram(
425+
ale_idx,
426+
ale_hist[ale_idx],
427+
exp_hist_idx,
428+
exp_hist[exp_hist_idx],
429+
bin_centers,
430+
inv_step_size,
431+
n_bins,
432+
tmp_hist,
385433
)
386-
score_idx = np.floor(ale_scores * inv_step_size).astype(int)
387-
probabilities = np.outer(exp_hist[exp_idx], ale_hist[ale_idx]).ravel()
388-
389-
# Reset histogram and set probabilities.
390-
# Use at() instead of setting values directly (ale_hist[score_idx] = probabilities)
391-
# because there can be redundant values in score_idx.
392-
ale_hist = np.zeros(ale_hist.shape)
393-
np.add.at(ale_hist, score_idx, probabilities)
434+
ale_hist, tmp_hist = tmp_hist, ale_hist
394435

395436
self.null_distributions_["histweights_corr-none_method-approximate"] = ale_hist
396437

nimare/meta/cbma/base.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
_mask_img_to_bool,
2929
get_masker,
3030
mm2vox,
31-
vox2mm,
3231
)
3332

3433
LGR = logging.getLogger(__name__)
@@ -590,7 +589,7 @@ def _compute_null_reduced_montecarlo(self, ma_maps, n_iters=5000):
590589
null_dist = self._compute_summarystat(iter_ma_values)
591590
self.null_distributions_["values_corr-none_method-reducedMontecarlo"] = null_dist
592591

593-
def _compute_null_montecarlo_permutation(self, iter_xyz, iter_df, bin_edges=None):
592+
def _compute_null_montecarlo_permutation(self, iter_ijk, iter_df, bin_edges=None):
594593
"""Run a single Monte Carlo permutation of a dataset.
595594
596595
Does the shared work between uncorrected stat-to-p conversion and vFWE.
@@ -610,8 +609,8 @@ def _compute_null_montecarlo_permutation(self, iter_xyz, iter_df, bin_edges=None
610609
# be safe.
611610
iter_df = iter_df.copy()
612611

613-
iter_xyz = np.squeeze(iter_xyz)
614-
iter_df[["x", "y", "z"]] = iter_xyz
612+
iter_ijk = np.squeeze(iter_ijk)
613+
iter_df[["i", "j", "k"]] = iter_ijk
615614

616615
iter_ma_maps = self.kernel_transformer.transform(
617616
iter_df, masker=self.masker, return_type="sparse"
@@ -655,9 +654,8 @@ def _compute_null_montecarlo(self, n_iters, n_cores):
655654
size=(self.inputs_["coordinates"].shape[0], n_iters),
656655
)
657656
rand_ijk = null_ijk[rand_idx, :]
658-
rand_xyz = vox2mm(rand_ijk, self.masker.mask_img.affine)
659-
iter_xyzs = np.split(rand_xyz, rand_xyz.shape[1], axis=1)
660-
iter_df = self.inputs_["coordinates"].copy()
657+
iter_ijks = np.split(rand_ijk, rand_ijk.shape[1], axis=1)
658+
iter_df = self.inputs_["coordinates"].drop(columns=["x", "y", "z"], errors="ignore").copy()
661659
parallel_kwargs = {"return_as": "generator", "n_jobs": n_cores}
662660
if getattr(self, "_permutation_parallel_backend", None) is not None:
663661
parallel_kwargs["backend"] = self._permutation_parallel_backend
@@ -669,7 +667,7 @@ def _compute_null_montecarlo(self, n_iters, n_cores):
669667

670668
perm_histograms = Parallel(**parallel_kwargs)(
671669
delayed(self._compute_null_montecarlo_permutation)(
672-
iter_xyzs[i_iter],
670+
iter_ijks[i_iter],
673671
iter_df=iter_df,
674672
bin_edges=bin_edges,
675673
)
@@ -693,7 +691,7 @@ def _compute_null_montecarlo(self, n_iters, n_cores):
693691

694692
def _correct_fwe_montecarlo_permutation(
695693
self,
696-
iter_xyz,
694+
iter_ijk,
697695
iter_df,
698696
conn,
699697
voxel_thresh,
@@ -705,9 +703,9 @@ def _correct_fwe_montecarlo_permutation(
705703
706704
Parameters
707705
----------
708-
iter_xyz : :obj:`numpy.ndarray` of shape (C, 3)
709-
The permuted coordinates. One row for each peak.
710-
Columns correspond to x, y, and z coordinates.
706+
iter_ijk : :obj:`numpy.ndarray` of shape (C, 3)
707+
The permuted matrix indices. One row for each peak.
708+
Columns correspond to i, j, and k coordinates.
711709
iter_df : :obj:`pandas.DataFrame`
712710
The coordinates DataFrame, to be filled with the permuted matrix indices in ``iter_ijk``
713711
before permutation MA maps are generated.
@@ -727,8 +725,8 @@ def _correct_fwe_montecarlo_permutation(
727725
"""
728726
iter_df = iter_df.copy()
729727

730-
iter_xyz = np.squeeze(iter_xyz)
731-
iter_df[["x", "y", "z"]] = iter_xyz
728+
iter_ijk = np.squeeze(iter_ijk)
729+
iter_df[["i", "j", "k"]] = iter_ijk
732730

733731
iter_ma_maps = self.kernel_transformer.transform(
734732
iter_df, masker=self.masker, return_type="sparse"
@@ -864,23 +862,22 @@ def correct_fwe_montecarlo(
864862
"Running permutations from scratch."
865863
)
866864

867-
null_xyz = vox2mm(
868-
np.vstack(np.where(_mask_img_to_bool(self.masker.mask_img))).T,
869-
self.masker.mask_img.affine,
870-
)
865+
null_ijk = np.vstack(np.where(_mask_img_to_bool(self.masker.mask_img))).T
871866

872867
n_cores = _check_ncores(n_cores)
873868

874869
# Identify summary statistic corresponding to intensity threshold
875870
ss_thresh = self._p_to_summarystat(voxel_thresh)
876871

877872
rand_idx = np.random.choice(
878-
null_xyz.shape[0],
873+
null_ijk.shape[0],
879874
size=(self.inputs_["coordinates"].shape[0], n_iters),
880875
)
881-
rand_xyz = null_xyz[rand_idx, :]
882-
iter_xyzs = np.split(rand_xyz, rand_xyz.shape[1], axis=1)
883-
iter_df = self.inputs_["coordinates"].copy()
876+
rand_ijk = null_ijk[rand_idx, :]
877+
iter_ijks = np.split(rand_ijk, rand_ijk.shape[1], axis=1)
878+
iter_df = (
879+
self.inputs_["coordinates"].drop(columns=["x", "y", "z"], errors="ignore").copy()
880+
)
884881
parallel_kwargs = {"return_as": "generator", "n_jobs": n_cores}
885882
if getattr(self, "_permutation_parallel_backend", None) is not None:
886883
parallel_kwargs["backend"] = self._permutation_parallel_backend
@@ -890,7 +887,7 @@ def correct_fwe_montecarlo(
890887

891888
perm_results = Parallel(**parallel_kwargs)(
892889
delayed(self._correct_fwe_montecarlo_permutation)(
893-
iter_xyzs[i_iter],
890+
iter_ijks[i_iter],
894891
iter_df=iter_df,
895892
conn=conn,
896893
voxel_thresh=ss_thresh,

nimare/meta/kernel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ def transform(self, dataset, masker=None, return_type="image"):
151151
dataset : :obj:`~nimare.dataset.Dataset`, :obj:`~nimare.nimads.Studyset`, \
152152
or :obj:`pandas.DataFrame`
153153
Collection for which to make images. Can be a DataFrame if necessary.
154+
DataFrame inputs may provide precomputed matrix indices in ``i``, ``j``, and ``k``.
155+
When those columns are present, they are used directly and ``x``, ``y``, and ``z``
156+
are ignored.
154157
masker : img_like or None, optional
155158
Mask to apply to MA maps. Required if ``dataset`` is a DataFrame.
156159
If None, the input collection's masker attribute will be used.
@@ -195,11 +198,12 @@ def transform(self, dataset, masker=None, return_type="image"):
195198
masker is not None
196199
), "Argument 'masker' must be provided if dataset is a DataFrame."
197200
mask = masker.mask_img
198-
coordinates = dataset
201+
coordinates = dataset.copy()
199202

200-
# Calculate IJK. Must assume that the masker is in same space,
201-
# but has different affine, from original IJK.
202-
coordinates[["i", "j", "k"]] = mm2vox(dataset[["x", "y", "z"]], mask.affine)
203+
if not {"i", "j", "k"}.issubset(coordinates.columns):
204+
# Calculate IJK. Must assume that the masker is in same space,
205+
# but has different affine, from original IJK.
206+
coordinates[["i", "j", "k"]] = mm2vox(dataset[["x", "y", "z"]], mask.affine)
203207
else:
204208
if not isinstance(dataset, Dataset):
205209
dataset = normalize_collection(dataset)

nimare/tests/test_meta_ale.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,38 @@ def _prepare_ale_inputs(dataset, kernel_transformer=None):
112112
return meta
113113

114114

115+
def _study_ma_histogram_reference(
    study_ma_values, n_zero_voxels, mask_voxel_recip, inv_step_size, n_bins
):
    """Floor-based reference for binning one study's values.

    Each value goes to bin ``floor(value * inv_step_size)``, clamped to
    ``[0, n_bins - 1]``. ``n_zero_voxels`` is then added to the first bin and
    the whole histogram is multiplied by ``mask_voxel_recip``.
    """
    top_bin = n_bins - 1
    hist = np.zeros(n_bins, dtype=np.float64)

    for ma_value in study_ma_values:
        bin_idx = int(np.floor(ma_value * inv_step_size))
        if bin_idx < 0:
            bin_idx = 0
        elif bin_idx > top_bin:
            bin_idx = top_bin
        hist[bin_idx] += 1.0

    hist[0] += n_zero_voxels
    hist *= mask_voxel_recip
    return hist
128+
129+
130+
def _update_ale_histogram_reference(
    ale_idx, ale_probs, exp_idx, exp_probs, bin_centers, inv_step_size, n_bins
):
    """Floor-based reference for merging two sparse histograms.

    For every (exp, ale) pair the combined score is
    ``1 - (1 - exp_center) * (1 - ale_center)``; the product of the two
    probabilities is accumulated in bin ``floor(score * inv_step_size)``,
    clamped to ``[0, n_bins - 1]``. A freshly allocated histogram is returned.
    """
    merged = np.zeros(n_bins, dtype=np.float64)

    for exp_pos, exp_bin in enumerate(exp_idx):
        exp_weight = exp_probs[exp_pos]
        exp_complement = 1.0 - bin_centers[exp_bin]
        for ale_pos, ale_bin in enumerate(ale_idx):
            combined = 1.0 - exp_complement * (1.0 - bin_centers[ale_bin])
            target = int(np.floor(combined * inv_step_size))
            target = min(max(target, 0), n_bins - 1)
            merged[target] += exp_weight * ale_probs[ale_pos]

    return merged
145+
146+
115147
def test_ALE_missing_sample_sizes_raises_informative_error(testdata_cbma_full):
116148
"""Raise a helpful error listing ids when sample sizes are missing."""
117149
dset = copy.deepcopy(testdata_cbma_full)
@@ -374,6 +406,69 @@ def test_ALE_csr_approximate_null_matches_dense_reference():
374406
)
375407

376408

409+
def test_ALE_study_ma_histogram_edge_bins():
    """Compiled study-histogram binning agrees with the floor-based reference at bin edges."""
    n_bins = 11
    inv_step_size = 10.0
    n_zero_voxels = 3
    # Values sit exactly on, and just below, bin boundaries of the 0.1-wide grid.
    study_ma_values = np.array(
        [0.0, 0.099999999, 0.1, 0.199999999, 0.9, 0.999999999],
        dtype=np.float64,
    )
    mask_voxel_recip = 1.0 / (n_zero_voxels + 6)

    histogram_args = (
        study_ma_values,
        n_zero_voxels,
        mask_voxel_recip,
        inv_step_size,
        n_bins,
    )
    actual = ale._study_ma_histogram(*histogram_args)
    expected = _study_ma_histogram_reference(*histogram_args)

    np.testing.assert_allclose(actual, expected)
436+
437+
438+
def test_ALE_update_histogram_edge_bins():
    """Compiled histogram update agrees with the floor-based reference at bin edges."""
    bin_centers = np.linspace(0.0, 1.0, 11, dtype=np.float64)
    n_bins = bin_centers.shape[0]
    inv_step_size = 10.0

    # Both histograms occupy the first two and last two bins of the grid.
    edge_bins = [0, 1, 9, 10]
    ale_idx = np.array(edge_bins, dtype=np.int64)
    exp_idx = np.array(edge_bins, dtype=np.int64)
    ale_probs = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float64)
    exp_probs = np.array([0.4, 0.3, 0.2, 0.1], dtype=np.float64)

    shared_args = (
        ale_idx,
        ale_probs,
        exp_idx,
        exp_probs,
        bin_centers,
        inv_step_size,
        n_bins,
    )
    # The output buffer is left uninitialized; the update is expected to reset it.
    out = np.empty(n_bins, dtype=np.float64)

    actual = ale._update_ale_histogram(*shared_args, out)
    expected = _update_ale_histogram_reference(*shared_args)

    np.testing.assert_allclose(actual, expected)
470+
471+
377472
@pytest.mark.parametrize(
378473
("kernel_transformer", "sample_sizes"),
379474
[

0 commit comments

Comments
 (0)