6262 DEFAULT_FLOAT_DTYPE ,
6363 _check_ncores ,
6464 _mask_coverage_to_mask ,
65+ _mask_img_to_bool ,
6566 _p_to_logp_values ,
6667 mm2vox ,
6768 use_memmap ,
@@ -626,6 +627,10 @@ class ALESubtraction(PairwiseCBMAEstimator):
626627 vfwe_only : :obj:`bool`, default=True
627628 If True, only compute voxel-level null information. If False, also compute and retain
628629 cluster size and mass null distributions from the permutation maps.
630+ restrict_to_inference_mask : :obj:`bool`, default=False
631+ If True and directional inference maps are supplied to ``fit``, restrict permutation
632+ inference to the union of nonzero inference-map voxels. Observed group and contrast
633+ summary-statistic maps are still reported across the full estimator mask.
629634 memory : instance of :class:`joblib.Memory`, :obj:`str`, or :class:`pathlib.Path`
630635 Used to cache the output of a function. By default, no caching is done.
631636 If a :obj:`str` is given, it is the path to the caching directory.
@@ -684,6 +689,7 @@ def __init__(
684689 voxel_thresh = 0.001 ,
685690 low_memory = "auto" ,
686691 vfwe_only = True ,
692+ restrict_to_inference_mask = False ,
687693 memory = Memory (location = None , verbose = 0 ),
688694 memory_level = 0 ,
689695 n_cores = 1 ,
@@ -710,6 +716,7 @@ def __init__(
710716 self .voxel_thresh = voxel_thresh
711717 self .low_memory = low_memory
712718 self .vfwe_only = vfwe_only
719+ self .restrict_to_inference_mask = restrict_to_inference_mask
713720 self .n_cores = _check_ncores (n_cores )
714721 self ._permutation_parallel_backend = "threading"
715722 self ._low_memory_fraction = 0.5
@@ -839,6 +846,60 @@ def _compute_summarystat_est(self, ma_values):
839846 require_masked_csr (ma_values ) if sp_sparse .isspmatrix (ma_values ) else ma_values
840847 )
841848
849+ @staticmethod
850+ def _inference_union_mask (group1_mask , group2_mask ):
851+ """Build the voxel union for directional inference maps."""
852+ if group1_mask is None and group2_mask is None :
853+ return None
854+
855+ base = group1_mask if group1_mask is not None else group2_mask
856+ union_mask = np .zeros (base .shape , dtype = bool )
857+ if group1_mask is not None :
858+ union_mask |= group1_mask
859+ if group2_mask is not None :
860+ union_mask |= group2_mask
861+ if not np .any (union_mask ):
862+ raise ValueError (
863+ "Directional ALESubtraction inference requires at least one nonzero voxel in "
864+ "inference_map1 or inference_map2."
865+ )
866+ return union_mask
867+
868+ def _restrict_pairwise_ma_store (self , ma_store , union_mask ):
869+ """Slice a pairwise MA store to the inference union mask."""
870+ if union_mask is None :
871+ return ma_store
872+
873+ return _PairwiseMAStore (
874+ group1 = self ._slice_ma_group_columns (ma_store .group1 , union_mask ),
875+ group2 = self ._slice_ma_group_columns (ma_store .group2 , union_mask ),
876+ group1_stat = ma_store .group1_stat [union_mask ],
877+ group2_stat = ma_store .group2_stat [union_mask ],
878+ temp_files = [],
879+ )
880+
881+ @staticmethod
882+ def _slice_ma_group_columns (ma_group , column_mask ):
883+ """Slice CSR or chunked CSR MA maps to selected columns."""
884+ if isinstance (ma_group , _ChunkedCSRGroup ):
885+ chunks = [chunk [:, column_mask ] for chunk in ma_group .chunks ]
886+ return _ChunkedCSRGroup (
887+ chunks = chunks ,
888+ row_offsets = ma_group .row_offsets .copy (),
889+ shape = (ma_group .shape [0 ], int (np .count_nonzero (column_mask ))),
890+ )
891+ return ma_group [:, column_mask ]
892+
893+ @staticmethod
894+ def _scatter_to_full_mask (values , union_mask , fill_value = 0 ):
895+ """Scatter restricted masked values back to the full masker vector."""
896+ if union_mask is None :
897+ return values
898+
899+ full_values = np .full (union_mask .shape [0 ], fill_value , dtype = np .asarray (values ).dtype )
900+ full_values [union_mask ] = values
901+ return full_values
902+
842903 @use_memmap (LGR , n_files = 3 )
843904 def _fit (self , dataset1 , dataset2 ):
844905 self .dataset1 = dataset1
@@ -851,13 +912,28 @@ def _fit(self, dataset1, dataset2):
851912 group1_mask = None if inference_map1 is None else np .asarray (inference_map1 ) > 0
852913 group2_mask = None if inference_map2 is None else np .asarray (inference_map2 ) > 0
853914
915+ union_mask = None
916+ has_inference_mask = group1_mask is not None or group2_mask is not None
917+ if has_inference_mask :
918+ inference_union_mask = self ._inference_union_mask (group1_mask , group2_mask )
919+ if self .restrict_to_inference_mask :
920+ union_mask = inference_union_mask
921+
854922 with self ._managed_pairwise_ma_store (
855923 maps_key1 = "ma_maps1" ,
856924 coords_key1 = "coordinates1" ,
857925 maps_key2 = "ma_maps2" ,
858926 coords_key2 = "coordinates2" ,
859927 ) as ma_store :
860- diff_ale_values = ma_store .group1_stat - ma_store .group2_stat
928+ fit_store = self ._restrict_pairwise_ma_store (ma_store , union_mask )
929+ if union_mask is None :
930+ fit_group1_mask = group1_mask
931+ fit_group2_mask = group2_mask
932+ else :
933+ fit_group1_mask = None if group1_mask is None else group1_mask [union_mask ]
934+ fit_group2_mask = None if group2_mask is None else group2_mask [union_mask ]
935+ full_diff_ale_values = ma_store .group1_stat - ma_store .group2_stat
936+ diff_ale_values = fit_store .group1_stat - fit_store .group2_stat
861937
862938 try :
863939 if not self .vfwe_only :
@@ -866,17 +942,17 @@ def _fit(self, dataset1, dataset2):
866942 self .memmap_filenames [2 ],
867943 dtype = DEFAULT_FLOAT_DTYPE ,
868944 mode = "w+" ,
869- shape = (self .n_iters , ma_store .n_voxels ),
945+ shape = (self .n_iters , fit_store .n_voxels ),
870946 )
871947
872948 iter_abs_max , p_values , diff_signs = self ._run_null_permutations (
873- ma_store ,
949+ fit_store ,
874950 n_iters = self .n_iters ,
875951 n_cores = self .n_cores ,
876952 diff_ale_values = diff_ale_values ,
877953 iter_diff_values = iter_diff_values ,
878- group1_mask = group1_mask ,
879- group2_mask = group2_mask ,
954+ group1_mask = fit_group1_mask ,
955+ group2_mask = fit_group2_mask ,
880956 )
881957 self .null_distributions_ ["values_level-voxel_corr-fwe_method-montecarlo" ] = (
882958 iter_abs_max
@@ -890,6 +966,7 @@ def _fit(self, dataset1, dataset2):
890966 iter_diff_values ,
891967 voxel_thresh = self .voxel_thresh ,
892968 n_iters = self .n_iters ,
969+ union_mask = union_mask ,
893970 )
894971 self .null_distributions_ [
895972 "summary_stat_thresh_level-voxel_corr-fwe_method-montecarlo"
@@ -922,9 +999,12 @@ def _fit(self, dataset1, dataset2):
922999 z_tail = "one" if (group1_mask is not None or group2_mask is not None ) else "two"
9231000 z_arr = p_to_z (p_values , tail = z_tail ) * diff_signs
9241001 logp_arr = _p_to_logp_values (p_values , dtype = DEFAULT_FLOAT_DTYPE )
1002+ p_values = self ._scatter_to_full_mask (p_values , union_mask , fill_value = 1 )
1003+ z_arr = self ._scatter_to_full_mask (z_arr , union_mask )
1004+ logp_arr = self ._scatter_to_full_mask (logp_arr , union_mask )
9251005
9261006 maps = {
927- "stat_desc-group1MinusGroup2" : diff_ale_values ,
1007+ "stat_desc-group1MinusGroup2" : full_diff_ale_values ,
9281008 "p_desc-group1MinusGroup2" : p_values ,
9291009 "z_desc-group1MinusGroup2" : z_arr ,
9301010 "logp_desc-group1MinusGroup2" : logp_arr ,
@@ -1013,15 +1093,27 @@ def _run_null_permutations(
10131093
10141094 return iter_abs_max , p_values , diff_signs
10151095
1016- def _compute_cluster_nulls (self , iter_diff_values , voxel_thresh , n_iters ):
1096+ def _compute_cluster_nulls (self , iter_diff_values , voxel_thresh , n_iters , union_mask = None ):
10171097 """Compute cluster-forming threshold and cluster null summaries from permutation maps."""
1018- ss_thresh = np .quantile (np .abs (iter_diff_values ), 1 - voxel_thresh )
1098+ # When union_mask is provided, restrict the ss_thresh quantile and cluster stats to the
1099+ # masked region so that null clusters can only form where inference maps have signal.
1100+ is_restricted = union_mask is not None and iter_diff_values .shape [1 ] == union_mask .sum ()
1101+ if union_mask is not None and not is_restricted :
1102+ ss_thresh = np .quantile (np .abs (iter_diff_values [:, union_mask ]), 1 - voxel_thresh )
1103+ else :
1104+ ss_thresh = np .quantile (np .abs (iter_diff_values ), 1 - voxel_thresh )
10191105 conn = ndimage .generate_binary_structure (rank = 3 , connectivity = 1 )
10201106 iter_max_sizes = np .zeros (n_iters , dtype = DEFAULT_FLOAT_DTYPE )
10211107 iter_max_masses = np .zeros (n_iters , dtype = DEFAULT_FLOAT_DTYPE )
10221108
10231109 for i_iter in range (n_iters ):
1024- iter_map = self .masker .inverse_transform (iter_diff_values [i_iter , :]).get_fdata (
1110+ iter_vals = iter_diff_values [i_iter , :]
1111+ if is_restricted :
1112+ iter_vals = self ._scatter_to_full_mask (iter_vals , union_mask )
1113+ elif union_mask is not None :
1114+ iter_vals = iter_vals .copy ()
1115+ iter_vals [~ union_mask ] = 0
1116+ iter_map = self .masker .inverse_transform (iter_vals ).get_fdata (
10251117 dtype = DEFAULT_FLOAT_DTYPE
10261118 )
10271119 iter_max_sizes [i_iter ], iter_max_masses [i_iter ] = _calculate_cluster_measures (
@@ -1624,6 +1716,9 @@ def _probabilistic_map(self, dataset, target_n, seed):
16241716 sample_space = np .vstack (np .where (prior_img )).T .astype (np .int32 , copy = False )
16251717 rng = np .random .RandomState (seed )
16261718
1719+ # Precompute boolean mask once to avoid NiBabel round-trip in hot loops.
1720+ mask_arr = _mask_img_to_bool (estimator .masker .mask_img )
1721+
16271722 null_cluster_sizes = np .zeros (self .n_iters , dtype = np .int32 )
16281723 for i_iter in range (self .n_iters ):
16291724 if target_n < ma_maps .shape [0 ]:
@@ -1646,6 +1741,7 @@ def _probabilistic_map(self, dataset, target_n, seed):
16461741 estimator .masker ,
16471742 voxel_thresh = self .voxel_thresh ,
16481743 cluster_size_threshold = None ,
1744+ mask_arr = mask_arr ,
16491745 )
16501746
16511747 cluster_threshold = np .percentile (null_cluster_sizes , 100.0 * (1.0 - self .alpha ))
@@ -1664,6 +1760,7 @@ def _probabilistic_map(self, dataset, target_n, seed):
16641760 estimator .masker ,
16651761 voxel_thresh = self .voxel_thresh ,
16661762 cluster_size_threshold = cluster_threshold ,
1763+ mask_arr = mask_arr ,
16671764 )
16681765 prob_map += (z_values > 0 ).astype (DEFAULT_FLOAT_DTYPE , copy = False )
16691766 return (
0 commit comments