1616from matplotlib .colors import ListedColormap , LinearSegmentedColormap
1717from matplotlib import gridspec
1818
19- from xplique .attributions .global_sensitivity_analysis import (HaltonSequenceRS , JansenEstimator )
19+ from xplique .attributions .global_sensitivity_analysis import \
20+ (HaltonSequenceRS , ScipySobolSequenceRS , LatinHypercubeRS , JansenEstimator )
2021from xplique .plots .image import _clip_percentile
2122
2223from ..types import Callable , Tuple , Optional , Union
@@ -144,6 +145,14 @@ class DisplayImportancesOrder(Enum):
144145 def __eq__ (self , other ):
145146 return self .value == other .value
146147
148+ class MaskSampler (Enum ):
149+ HALTON = HaltonSequenceRS
150+ SOBOL = ScipySobolSequenceRS
151+ LATIN = LatinHypercubeRS
152+
153+ def __eq__ (self , other ):
154+ return self .value == other .value
155+
147156class BaseCraft (BaseConceptExtractor , ABC ):
148157 """
149158 Base class implementing the CRAFT Concept Extraction Mechanism.
@@ -292,7 +301,10 @@ def transform(self, inputs : np.ndarray, activations : np.ndarray = None) -> np.
292301 coeffs_u = np .reshape (coeffs_u , (* original_shape , coeffs_u .shape [- 1 ]))
293302 return coeffs_u
294303
295- def estimate_importance (self , inputs : np .ndarray = None , nb_design : int = 32 ) -> np .ndarray :
304+ def estimate_importance (self ,
305+ inputs : np .ndarray = None ,
306+ sampler : MaskSampler = MaskSampler .HALTON ,
307+ nb_design : int = 32 ) -> np .ndarray :
296308 """
297309 Estimates the importance of each concept for a given class, either globally
298310 on the whole dataset provided in the fit() method (in this case, inputs shall
@@ -305,6 +317,8 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
305317 If None, then the inputs provided in the fit() method
306318 will be used (global importance of the whole dataset).
307319 Default is None.
320+ sampler
321+ The sampling method to use for masking. Default to MaskSampler.HALTON.
308322 nb_design
309323 The number of design to use for the importance estimation. Default is 32.
310324
@@ -323,7 +337,7 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
323337
324338 coeffs_u = self .transform (inputs )
325339
326- masks = HaltonSequenceRS ()(self .number_of_concepts , nb_design = nb_design )
340+ masks = sampler . value ()(self .number_of_concepts , nb_design = nb_design )
327341 estimator = JansenEstimator ()
328342
329343 importances = []
0 commit comments