Skip to content

Commit 4da7952

Browse files
authored
[REF] improve speed of NiMARE implementation of JALE (#1075)
* improve speed of NiMARE implementation of JALE
  - add restrict_to_inference_mask to ALESubtraction to reduce the number of voxels being calculated over when doing permutations (note: this makes CSR more of a burden than a help)
  - ContrastWorkflow accepts results and corrected results as inputs, so it does not have to recompute the main effects
  - add generate_description=False support for both ContrastWorkflow and ResampledStability
  - _threshold_z_clusters no longer uses nilearn for transforming
  - ResampledStability now reuses the full-dataset approximate null for the subsamples, instead of recalculating the null for each subsample
  - gray-matter prior resampling is now cached
* update the number of iterations
1 parent 7640aa3 commit 4da7952

10 files changed

Lines changed: 554 additions & 63 deletions

File tree

nimare/diagnostics.py

Lines changed: 64 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,14 @@
2727
)
2828
from nimare.meta.ibma import IBMAEstimator
2929
from nimare.nimads import Studyset
30+
from nimare.results import MetaResult
3031
from nimare.studyset import normalize_collection
3132
from nimare.utils import (
3233
DEFAULT_FLOAT_DTYPE,
3334
_check_ncores,
3435
_filter_kwargs,
3536
_mask_coverage_to_null_ijk,
37+
_mask_img_to_bool,
3638
get_masker,
3739
mm2vox,
3840
)
@@ -715,6 +717,9 @@ class ResampledStability(NiMAREBase):
715717
n_cores : int, optional
716718
Number of cores to use for parallelization.
717719
If <=0, defaults to using all available cores. Default is 1.
720+
generate_description : bool, optional
721+
Whether to append boilerplate text and extract references for the returned result.
722+
Default is True.
718723
"""
719724

720725
def __init__(
@@ -730,6 +735,7 @@ def __init__(
730735
mask_coverage="gm",
731736
alpha=0.05,
732737
n_cores=1,
738+
generate_description=True,
733739
):
734740
if mask_coverage not in ("gm", "brain"):
735741
raise ValueError("mask_coverage must be 'gm' or 'brain'.")
@@ -746,6 +752,7 @@ def __init__(
746752
self.mask_coverage = mask_coverage
747753
self.alpha = alpha
748754
self.n_cores = _check_ncores(n_cores)
755+
self.generate_description = generate_description
749756

750757
def _resolve_subsets(self, n_studies):
751758
"""Build a replicate schedule in study-index space."""
@@ -815,17 +822,27 @@ def _fit_replicate(self, kept_ids, result):
815822
return self._extract_binary_support(replicate_result)
816823

817824
def _fit_cbma_subset_replicate(
    self,
    subset_idx,
    ma_maps,
    estimator,
    study_ids,
    cluster_threshold,
    precomputed_null=None,
    mask_arr=None,
):
    """Return the binary support map for one retained-study CBMA replicate.

    The rows of ``ma_maps`` and the entries of ``study_ids`` selected by
    ``subset_idx`` are converted to approximate z-values, cluster-thresholded,
    and binarized.

    Parameters
    ----------
    subset_idx
        Index used to select rows of ``ma_maps`` and entries of ``study_ids``.
    ma_maps
        Cached MA maps, indexed as ``ma_maps[subset_idx, :]``.
    estimator
        Fitted estimator; its ``masker`` is passed to the cluster thresholding step.
    study_ids
        Study identifiers aligned with the rows of ``ma_maps``.
    cluster_threshold
        Cluster-size threshold forwarded as ``cluster_size_threshold``.
    precomputed_null : optional
        Forwarded unchanged to ``_approximate_z_from_ma``. Default is None.
    mask_arr : optional
        Forwarded unchanged to ``_threshold_z_clusters``. Default is None.

    Returns
    -------
    Array of ``DEFAULT_FLOAT_DTYPE`` holding 1 where the thresholded z-value
    is positive and 0 elsewhere.
    """
    retained_ma = ma_maps[subset_idx, :]
    retained_ids = study_ids[subset_idx]

    _, z_values = _approximate_z_from_ma(
        estimator,
        retained_ma,
        retained_ids,
        precomputed_null=precomputed_null,
    )

    # Fall back to the conventional 0.001 when no voxel threshold is configured.
    forming_threshold = self.voxel_thresh or 0.001
    z_values, _ = _threshold_z_clusters(
        z_values,
        estimator.masker,
        voxel_thresh=forming_threshold,
        cluster_size_threshold=cluster_threshold,
        mask_arr=mask_arr,
    )

    surviving = z_values > 0
    return surviving.astype(DEFAULT_FLOAT_DTYPE, copy=False)
831848

@@ -846,16 +863,30 @@ def _cbma_subset_stability(self, result, subsets, target_n):
846863
estimator.masker, mask_coverage=self.mask_coverage
847864
).astype(np.int32, copy=False)
848865

866+
# Build the full-dataset approximate null once and reuse it for every
867+
# subsample and null-MA iteration (mirrors JALE's hx_conv reuse).
868+
full_null_temp = copy.deepcopy(estimator)
869+
full_null_temp.null_distributions_ = {}
870+
full_null_temp._prepare_subsample_null(ma_maps)
871+
full_null_temp._compute_approximate_z_values(ma_maps)
872+
precomputed_null = full_null_temp.null_distributions_
873+
874+
# Precompute boolean mask array once to avoid NiBabel round-trip in hot loops.
875+
mask_arr = _mask_img_to_bool(estimator.masker.mask_img)
876+
849877
rng = np.random.RandomState(self.random_state)
850878
null_cluster_sizes = np.zeros(montecarlo_iters, dtype=np.int32)
851879
for i_iter in range(montecarlo_iters):
852880
null_ma, subset_ids = estimator._generate_random_null_ma(target_n, sample_space, rng)
853-
_, null_z = _approximate_z_from_ma(estimator, null_ma, subset_ids)
881+
_, null_z = _approximate_z_from_ma(
882+
estimator, null_ma, subset_ids, precomputed_null=precomputed_null
883+
)
854884
_, null_cluster_sizes[i_iter] = _threshold_z_clusters(
855885
null_z,
856886
estimator.masker,
857887
voxel_thresh=cluster_forming_threshold,
858888
cluster_size_threshold=None,
889+
mask_arr=mask_arr,
859890
)
860891

861892
cluster_threshold = np.percentile(null_cluster_sizes, 100.0 * (1.0 - self.alpha))
@@ -865,7 +896,13 @@ def _cbma_subset_stability(self, result, subsets, target_n):
865896
for support in tqdm(
866897
Parallel(return_as="generator", n_jobs=self.n_cores)(
867898
delayed(self._fit_cbma_subset_replicate)(
868-
subset_idx, ma_maps, estimator, study_ids, cluster_threshold
899+
subset_idx,
900+
ma_maps,
901+
estimator,
902+
study_ids,
903+
cluster_threshold,
904+
precomputed_null=precomputed_null,
905+
mask_arr=mask_arr,
869906
)
870907
for subset_idx in subsets
871908
),
@@ -880,7 +917,7 @@ def _cbma_subset_stability(self, result, subsets, target_n):
880917

881918
def _finalize_result(self, result, stability_map, n_resamples_used, target_n_used):
882919
"""Attach stability map and summary table to a copied result object."""
883-
result = result.copy()
920+
result = self._copy_result_for_diagnostic(result)
884921
map_name = f"{self.target_image}_diag-ResampledStability"
885922
result.maps[map_name] = stability_map
886923
result.tables[f"{map_name}_tab-summary"] = pd.DataFrame(
@@ -896,14 +933,30 @@ def _finalize_result(self, result, stability_map, n_resamples_used, target_n_use
896933
]
897934
)
898935
result.diagnostics.append(self)
899-
result.description_ += (
900-
" Voxelwise stability of thresholded results was estimated by repeatedly "
901-
"resampling the input dataset, recomputing thresholded support maps, and averaging "
902-
"the binary support across resamples. This diagnostic follows the resampling-based "
903-
"stability approach implemented in JALE \\citep{Frahm_Monimu_Hoffstaedter}."
904-
)
936+
if self.generate_description:
937+
result.description_ += (
938+
" Voxelwise stability of thresholded results was estimated by repeatedly "
939+
"resampling the input dataset, recomputing thresholded support maps, and "
940+
"averaging the binary support across resamples. This diagnostic follows the "
941+
"resampling-based stability approach implemented "
942+
"in JALE \\citep{Frahm_Monimu_Hoffstaedter}."
943+
)
905944
return result
906945

946+
@staticmethod
def _copy_result_for_diagnostic(result):
    """Build a shallow MetaResult clone that diagnostics can safely extend.

    The clone is created without calling ``MetaResult.__init__``. The
    estimator, corrector, and masker are shared by reference, while the
    diagnostics list and the maps, tables, and metadata containers are
    copied one level deep, so adding entries to the clone leaves ``result``
    untouched. The stored values themselves are not copied.

    Parameters
    ----------
    result
        Source result object to clone.

    Returns
    -------
    New ``MetaResult`` instance carrying the same content as ``result``.
    """
    clone = object.__new__(MetaResult)

    # Shared by reference.
    clone.estimator = result.estimator
    clone.corrector = result.corrector

    # Fresh top-level containers so new diagnostic entries stay local to the clone.
    clone.diagnostics = list(result.diagnostics)
    clone.masker = result.masker
    clone.maps = dict(result.maps)
    clone.tables = dict(result.tables)

    # ``metadata`` may be missing on the source result; default to empty.
    source_metadata = getattr(result, "metadata", {})
    clone.metadata = dict(source_metadata)

    clone._set_description(result.description_)
    return clone
959+
907960
def transform(self, result):
908961
"""Apply the resampling diagnostic to a fitted meta-analytic result."""
909962
if issubclass(type(result.estimator), PairwiseCBMAEstimator):

nimare/meta/cbma/ale.py

Lines changed: 106 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@
6262
DEFAULT_FLOAT_DTYPE,
6363
_check_ncores,
6464
_mask_coverage_to_mask,
65+
_mask_img_to_bool,
6566
_p_to_logp_values,
6667
mm2vox,
6768
use_memmap,
@@ -626,6 +627,10 @@ class ALESubtraction(PairwiseCBMAEstimator):
626627
vfwe_only : :obj:`bool`, default=True
627628
If True, only compute voxel-level null information. If False, also compute and retain
628629
cluster size and mass null distributions from the permutation maps.
630+
restrict_to_inference_mask : :obj:`bool`, default=False
631+
If True and directional inference maps are supplied to ``fit``, restrict permutation
632+
inference to the union of nonzero inference-map voxels. Observed group and contrast
633+
summary-statistic maps are still reported across the full estimator mask.
629634
memory : instance of :class:`joblib.Memory`, :obj:`str`, or :class:`pathlib.Path`
630635
Used to cache the output of a function. By default, no caching is done.
631636
If a :obj:`str` is given, it is the path to the caching directory.
@@ -684,6 +689,7 @@ def __init__(
684689
voxel_thresh=0.001,
685690
low_memory="auto",
686691
vfwe_only=True,
692+
restrict_to_inference_mask=False,
687693
memory=Memory(location=None, verbose=0),
688694
memory_level=0,
689695
n_cores=1,
@@ -710,6 +716,7 @@ def __init__(
710716
self.voxel_thresh = voxel_thresh
711717
self.low_memory = low_memory
712718
self.vfwe_only = vfwe_only
719+
self.restrict_to_inference_mask = restrict_to_inference_mask
713720
self.n_cores = _check_ncores(n_cores)
714721
self._permutation_parallel_backend = "threading"
715722
self._low_memory_fraction = 0.5
@@ -839,6 +846,60 @@ def _compute_summarystat_est(self, ma_values):
839846
require_masked_csr(ma_values) if sp_sparse.isspmatrix(ma_values) else ma_values
840847
)
841848

849+
@staticmethod
def _inference_union_mask(group1_mask, group2_mask):
    """Build the voxel union for directional inference maps.

    Parameters
    ----------
    group1_mask, group2_mask : :obj:`numpy.ndarray` of bool or None
        Boolean masks marking the voxels selected by each directional
        inference map. Either (or both) may be None.

    Returns
    -------
    union_mask : :obj:`numpy.ndarray` of bool or None
        None when both inputs are None. Otherwise a newly allocated boolean
        array holding the elementwise OR of the supplied masks.

    Raises
    ------
    ValueError
        If both masks are supplied with different shapes, or if the
        resulting union does not contain any True voxel.
    """
    supplied = [mask for mask in (group1_mask, group2_mask) if mask is not None]
    if not supplied:
        return None

    # Reject mismatched masks up front: the in-place OR below would otherwise
    # silently broadcast a smaller second mask across the first one, or fail
    # with an unrelated broadcasting error depending on argument order.
    if len(supplied) == 2 and supplied[0].shape != supplied[1].shape:
        raise ValueError(
            "inference_map1 and inference_map2 must have the same shape, got "
            f"{supplied[0].shape} and {supplied[1].shape}."
        )

    # Accumulate into a fresh array so the caller's masks are never mutated.
    union_mask = np.zeros(supplied[0].shape, dtype=bool)
    for mask in supplied:
        union_mask |= mask

    if not np.any(union_mask):
        raise ValueError(
            "Directional ALESubtraction inference requires at least one nonzero voxel in "
            "inference_map1 or inference_map2."
        )
    return union_mask
867+
868+
def _restrict_pairwise_ma_store(self, ma_store, union_mask):
    """Return a pairwise MA store limited to the voxels selected by ``union_mask``.

    Parameters
    ----------
    ma_store
        Pairwise MA store exposing ``group1``, ``group2``, ``group1_stat``,
        and ``group2_stat``.
    union_mask : boolean array or None
        Voxel selection. When None, ``ma_store`` is returned unchanged.

    Returns
    -------
    ``ma_store`` itself when no mask is given; otherwise a new
    ``_PairwiseMAStore`` whose MA maps and summary statistics are sliced to
    the masked voxels and whose ``temp_files`` list is empty.
    """
    if union_mask is not None:
        restricted_group1 = self._slice_ma_group_columns(ma_store.group1, union_mask)
        restricted_group2 = self._slice_ma_group_columns(ma_store.group2, union_mask)
        restricted_stat1 = ma_store.group1_stat[union_mask]
        restricted_stat2 = ma_store.group2_stat[union_mask]
        return _PairwiseMAStore(
            group1=restricted_group1,
            group2=restricted_group2,
            group1_stat=restricted_stat1,
            group2_stat=restricted_stat2,
            temp_files=[],
        )

    return ma_store
880+
881+
@staticmethod
def _slice_ma_group_columns(ma_group, column_mask):
    """Keep only the columns of an MA-map group selected by ``column_mask``.

    Parameters
    ----------
    ma_group
        Either a ``_ChunkedCSRGroup`` or an object supporting
        ``ma_group[:, column_mask]`` indexing.
    column_mask
        Column selector applied to the group (or to each of its chunks).

    Returns
    -------
    A new ``_ChunkedCSRGroup`` with every chunk sliced when the input is
    chunked; otherwise the result of ``ma_group[:, column_mask]``.
    """
    if not isinstance(ma_group, _ChunkedCSRGroup):
        return ma_group[:, column_mask]

    sliced_chunks = [chunk[:, column_mask] for chunk in ma_group.chunks]
    offsets = ma_group.row_offsets.copy()
    # The row count is unchanged; the column count becomes the number of kept columns.
    n_rows = ma_group.shape[0]
    n_kept = int(np.count_nonzero(column_mask))
    return _ChunkedCSRGroup(
        chunks=sliced_chunks,
        row_offsets=offsets,
        shape=(n_rows, n_kept),
    )
892+
893+
@staticmethod
def _scatter_to_full_mask(values, union_mask, fill_value=0):
    """Expand restricted values back to the length of the full mask vector.

    Parameters
    ----------
    values : array-like
        Values for the voxels selected by ``union_mask``.
    union_mask : boolean array or None
        Selector over the full voxel vector. When None, ``values`` is
        returned unchanged.
    fill_value : scalar, optional
        Value written to voxels outside ``union_mask``. Default is 0.

    Returns
    -------
    ``values`` itself when ``union_mask`` is None; otherwise a new 1D array
    with ``union_mask.shape[0]`` entries and the dtype of ``values``.
    """
    if union_mask is None:
        return values

    restricted = np.asarray(values)
    n_full = union_mask.shape[0]
    scattered = np.full(n_full, fill_value, dtype=restricted.dtype)
    scattered[union_mask] = restricted
    return scattered
902+
842903
@use_memmap(LGR, n_files=3)
843904
def _fit(self, dataset1, dataset2):
844905
self.dataset1 = dataset1
@@ -851,13 +912,28 @@ def _fit(self, dataset1, dataset2):
851912
group1_mask = None if inference_map1 is None else np.asarray(inference_map1) > 0
852913
group2_mask = None if inference_map2 is None else np.asarray(inference_map2) > 0
853914

915+
union_mask = None
916+
has_inference_mask = group1_mask is not None or group2_mask is not None
917+
if has_inference_mask:
918+
inference_union_mask = self._inference_union_mask(group1_mask, group2_mask)
919+
if self.restrict_to_inference_mask:
920+
union_mask = inference_union_mask
921+
854922
with self._managed_pairwise_ma_store(
855923
maps_key1="ma_maps1",
856924
coords_key1="coordinates1",
857925
maps_key2="ma_maps2",
858926
coords_key2="coordinates2",
859927
) as ma_store:
860-
diff_ale_values = ma_store.group1_stat - ma_store.group2_stat
928+
fit_store = self._restrict_pairwise_ma_store(ma_store, union_mask)
929+
if union_mask is None:
930+
fit_group1_mask = group1_mask
931+
fit_group2_mask = group2_mask
932+
else:
933+
fit_group1_mask = None if group1_mask is None else group1_mask[union_mask]
934+
fit_group2_mask = None if group2_mask is None else group2_mask[union_mask]
935+
full_diff_ale_values = ma_store.group1_stat - ma_store.group2_stat
936+
diff_ale_values = fit_store.group1_stat - fit_store.group2_stat
861937

862938
try:
863939
if not self.vfwe_only:
@@ -866,17 +942,17 @@ def _fit(self, dataset1, dataset2):
866942
self.memmap_filenames[2],
867943
dtype=DEFAULT_FLOAT_DTYPE,
868944
mode="w+",
869-
shape=(self.n_iters, ma_store.n_voxels),
945+
shape=(self.n_iters, fit_store.n_voxels),
870946
)
871947

872948
iter_abs_max, p_values, diff_signs = self._run_null_permutations(
873-
ma_store,
949+
fit_store,
874950
n_iters=self.n_iters,
875951
n_cores=self.n_cores,
876952
diff_ale_values=diff_ale_values,
877953
iter_diff_values=iter_diff_values,
878-
group1_mask=group1_mask,
879-
group2_mask=group2_mask,
954+
group1_mask=fit_group1_mask,
955+
group2_mask=fit_group2_mask,
880956
)
881957
self.null_distributions_["values_level-voxel_corr-fwe_method-montecarlo"] = (
882958
iter_abs_max
@@ -890,6 +966,7 @@ def _fit(self, dataset1, dataset2):
890966
iter_diff_values,
891967
voxel_thresh=self.voxel_thresh,
892968
n_iters=self.n_iters,
969+
union_mask=union_mask,
893970
)
894971
self.null_distributions_[
895972
"summary_stat_thresh_level-voxel_corr-fwe_method-montecarlo"
@@ -922,9 +999,12 @@ def _fit(self, dataset1, dataset2):
922999
z_tail = "one" if (group1_mask is not None or group2_mask is not None) else "two"
9231000
z_arr = p_to_z(p_values, tail=z_tail) * diff_signs
9241001
logp_arr = _p_to_logp_values(p_values, dtype=DEFAULT_FLOAT_DTYPE)
1002+
p_values = self._scatter_to_full_mask(p_values, union_mask, fill_value=1)
1003+
z_arr = self._scatter_to_full_mask(z_arr, union_mask)
1004+
logp_arr = self._scatter_to_full_mask(logp_arr, union_mask)
9251005

9261006
maps = {
927-
"stat_desc-group1MinusGroup2": diff_ale_values,
1007+
"stat_desc-group1MinusGroup2": full_diff_ale_values,
9281008
"p_desc-group1MinusGroup2": p_values,
9291009
"z_desc-group1MinusGroup2": z_arr,
9301010
"logp_desc-group1MinusGroup2": logp_arr,
@@ -1013,15 +1093,27 @@ def _run_null_permutations(
10131093

10141094
return iter_abs_max, p_values, diff_signs
10151095

1016-
def _compute_cluster_nulls(self, iter_diff_values, voxel_thresh, n_iters):
1096+
def _compute_cluster_nulls(self, iter_diff_values, voxel_thresh, n_iters, union_mask=None):
10171097
"""Compute cluster-forming threshold and cluster null summaries from permutation maps."""
1018-
ss_thresh = np.quantile(np.abs(iter_diff_values), 1 - voxel_thresh)
1098+
# When union_mask is provided, restrict the ss_thresh quantile and cluster stats to the
1099+
# masked region so that null clusters can only form where inference maps have signal.
1100+
is_restricted = union_mask is not None and iter_diff_values.shape[1] == union_mask.sum()
1101+
if union_mask is not None and not is_restricted:
1102+
ss_thresh = np.quantile(np.abs(iter_diff_values[:, union_mask]), 1 - voxel_thresh)
1103+
else:
1104+
ss_thresh = np.quantile(np.abs(iter_diff_values), 1 - voxel_thresh)
10191105
conn = ndimage.generate_binary_structure(rank=3, connectivity=1)
10201106
iter_max_sizes = np.zeros(n_iters, dtype=DEFAULT_FLOAT_DTYPE)
10211107
iter_max_masses = np.zeros(n_iters, dtype=DEFAULT_FLOAT_DTYPE)
10221108

10231109
for i_iter in range(n_iters):
1024-
iter_map = self.masker.inverse_transform(iter_diff_values[i_iter, :]).get_fdata(
1110+
iter_vals = iter_diff_values[i_iter, :]
1111+
if is_restricted:
1112+
iter_vals = self._scatter_to_full_mask(iter_vals, union_mask)
1113+
elif union_mask is not None:
1114+
iter_vals = iter_vals.copy()
1115+
iter_vals[~union_mask] = 0
1116+
iter_map = self.masker.inverse_transform(iter_vals).get_fdata(
10251117
dtype=DEFAULT_FLOAT_DTYPE
10261118
)
10271119
iter_max_sizes[i_iter], iter_max_masses[i_iter] = _calculate_cluster_measures(
@@ -1624,6 +1716,9 @@ def _probabilistic_map(self, dataset, target_n, seed):
16241716
sample_space = np.vstack(np.where(prior_img)).T.astype(np.int32, copy=False)
16251717
rng = np.random.RandomState(seed)
16261718

1719+
# Precompute boolean mask once to avoid NiBabel round-trip in hot loops.
1720+
mask_arr = _mask_img_to_bool(estimator.masker.mask_img)
1721+
16271722
null_cluster_sizes = np.zeros(self.n_iters, dtype=np.int32)
16281723
for i_iter in range(self.n_iters):
16291724
if target_n < ma_maps.shape[0]:
@@ -1646,6 +1741,7 @@ def _probabilistic_map(self, dataset, target_n, seed):
16461741
estimator.masker,
16471742
voxel_thresh=self.voxel_thresh,
16481743
cluster_size_threshold=None,
1744+
mask_arr=mask_arr,
16491745
)
16501746

16511747
cluster_threshold = np.percentile(null_cluster_sizes, 100.0 * (1.0 - self.alpha))
@@ -1664,6 +1760,7 @@ def _probabilistic_map(self, dataset, target_n, seed):
16641760
estimator.masker,
16651761
voxel_thresh=self.voxel_thresh,
16661762
cluster_size_threshold=cluster_threshold,
1763+
mask_arr=mask_arr,
16671764
)
16681765
prob_map += (z_values > 0).astype(DEFAULT_FLOAT_DTYPE, copy=False)
16691766
return (

0 commit comments

Comments
 (0)