Skip to content

Commit cc63413

Browse files
Allow providing a different sampler to Craft.
Currently by default the fit() method of Craft uses a hardcoded sampler. This patch allows Craft to use a sampler given in parameter. Signed-off-by: Frederic Boisnard <frederic.boisnard@irt-saintexupery.com>
1 parent 5743025 commit cc63413

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

xplique/concepts/craft.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
1717
from matplotlib import gridspec
1818

19-
from xplique.attributions.global_sensitivity_analysis import (HaltonSequenceRS, JansenEstimator)
19+
from xplique.attributions.global_sensitivity_analysis import \
20+
(HaltonSequenceRS, ScipySobolSequenceRS, LatinHypercubeRS, JansenEstimator)
2021
from xplique.plots.image import _clip_percentile
2122

2223
from ..types import Callable, Tuple, Optional, Union
@@ -144,6 +145,14 @@ class DisplayImportancesOrder(Enum):
144145
def __eq__(self, other):
145146
return self.value == other.value
146147

148+
class MaskSampler(Enum):
149+
HALTON = HaltonSequenceRS
150+
SOBOL = ScipySobolSequenceRS
151+
LATIN = LatinHypercubeRS
152+
153+
def __eq__(self, other):
154+
return self.value == other.value
155+
147156
class BaseCraft(BaseConceptExtractor, ABC):
148157
"""
149158
Base class implementing the CRAFT Concept Extraction Mechanism.
@@ -292,7 +301,10 @@ def transform(self, inputs : np.ndarray, activations : np.ndarray = None) -> np.
292301
coeffs_u = np.reshape(coeffs_u, (*original_shape, coeffs_u.shape[-1]))
293302
return coeffs_u
294303

295-
def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -> np.ndarray:
304+
def estimate_importance(self,
305+
inputs: np.ndarray = None,
306+
sampler: MaskSampler = MaskSampler.HALTON,
307+
nb_design: int = 32) -> np.ndarray:
296308
"""
297309
Estimates the importance of each concept for a given class, either globally
298310
on the whole dataset provided in the fit() method (in this case, inputs shall
@@ -305,6 +317,8 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
305317
If None, then the inputs provided in the fit() method
306318
will be used (global importance of the whole dataset).
307319
Default is None.
320+
sampler
321+
The sampling method to use for masking. Default to MaskSampler.HALTON.
308322
nb_design
309323
The number of design to use for the importance estimation. Default is 32.
310324
@@ -323,7 +337,7 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
323337

324338
coeffs_u = self.transform(inputs)
325339

326-
masks = HaltonSequenceRS()(self.number_of_concepts, nb_design = nb_design)
340+
masks = sampler.value()(self.number_of_concepts, nb_design = nb_design)
327341
estimator = JansenEstimator()
328342

329343
importances = []

0 commit comments

Comments
 (0)