diff --git a/tests/attributions/test_nlp_occlusion.py b/tests/attributions/test_nlp_occlusion.py
new file mode 100644
index 00000000..d1b60b20
--- /dev/null
+++ b/tests/attributions/test_nlp_occlusion.py
@@ -0,0 +1,47 @@
+"""
+Test the NLP occlusion attribution method (NlpOcclusion)
+"""
+import numpy as np
+
+from xplique.attributions import NlpOcclusion
+
+def test_masks():
+ """Test the masks creation"""
+ sentence = "aaa bbb ccc"
+ words = sentence.split(" ")
+    masks = NlpOcclusion._get_masks(len(words))
+ assert masks.shape == (len(words), len(words))
+ expected_mask = np.array([[False, True, True],
+ [True, False, True],
+ [True, True, False]])
+
+ assert np.array_equal(masks, expected_mask)
+
+def test_apply_masks():
+ """Test if the application of a mask generate valid results"""
+ sentence = "aaa bbb ccc"
+ words = sentence.split(" ")
+    masks = NlpOcclusion._get_masks(len(words))
+
+ occluded_inputs = NlpOcclusion._apply_masks(words, masks)
+    expected_occluded_inputs = [['bbb', 'ccc'], ['aaa', 'ccc'], ['aaa', 'bbb']]
+    assert np.array_equal(occluded_inputs, expected_occluded_inputs)
+
+def test_output_shape():
+ """Test the output shape for several input sentences"""
+
+ nb_concepts = 10
+
+ def transform(inputs):
+        # simulate the transform method used in Craft/Cockatiel
+ return np.ones((len(inputs), nb_concepts))
+
+ input_sentence = ["aaa bbb ccc ddd eee fff", "ggg hhh iii jjj"]
+ for sentence in input_sentence:
+ words = sentence.split(" ")
+ separator = " "
+
+ method = NlpOcclusion(model=transform)
+ sensitivity = method.explain(sentence, words, separator)
+
+ assert sensitivity.shape == (nb_concepts, len(words))
diff --git a/tests/concepts/test_cockatiel.py b/tests/concepts/test_cockatiel.py
new file mode 100644
index 00000000..6f4e5e57
--- /dev/null
+++ b/tests/concepts/test_cockatiel.py
@@ -0,0 +1,240 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch.nn import MSELoss
+from transformers import RobertaPreTrainedModel, RobertaModel
+from transformers import RobertaTokenizerFast
+from xplique.concepts import CockatielTorch as Cockatiel
+from xplique.commons.torch_operations import NlpPreprocessor
+
+
+class FakeImdbClassifier(torch.nn.Module):
+ def __init__(self, nb_classes):
+ super().__init__()
+ self.nb_classes = nb_classes
+
+ def features(self, **kwargs):
+ start_value = 0
+ return torch.arange(start_value,
+ start_value + kwargs['input_ids'].shape[0] * self.nb_classes).\
+ reshape(kwargs['input_ids'].shape[0], self.nb_classes)
+
+ def classifier(self, latent):
+ # Simulates a classifier, returns alternatively [0, 1] and [1, 0]
+ # as result
+ n = len(latent)
+ pattern = torch.cat([torch.tensor([0, 1]), torch.tensor([1, 0])])
+ ypred = torch.cat([pattern] * n).reshape(n * 2, -1)
+ return ypred
+
+ def forward(self, **kwargs):
+ return self.classifier(self.features(**kwargs))
+
+
+class ImdbPreprocessor(NlpPreprocessor):
+ def preprocess(self, inputs: np.ndarray, labels: np.ndarray):
+ preprocessed_inputs = self.tokenize(samples=inputs.tolist())
+ preprocessed_labels = torch.Tensor(np.array(
+ labels.tolist()) == 'positive').int().to(self.device)
+ return preprocessed_inputs, preprocessed_labels
+
+
+def test_shape():
+ """Ensure the output shape is correct"""
+
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ pretrained_model_path = "wrmurray/roberta-base-finetuned-imdb"
+ tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_path)
+ nb_classes = 2
+ model = FakeImdbClassifier(nb_classes=nb_classes)
+ model = model.eval()
+
+ imdb_preprocessor = ImdbPreprocessor(tokenizer,
+ device,
+ padding="max_length",
+ max_length=512,
+ truncation=True,
+ return_tensors='pt')
+
+ input_to_latent_model = model.features
+ latent_to_logit_model = model.classifier
+ number_of_concepts = 20
+ cockatiel_explainer_pos = Cockatiel(input_to_latent_model=input_to_latent_model,
+ latent_to_logit_model=latent_to_logit_model,
+ preprocessor=imdb_preprocessor,
+ number_of_concepts=number_of_concepts,
+ batch_size=64,
+ device=device)
+
+ data = ["Absolutely riveting from beginning to end. Test 1 2 3. Another test.",
+ "But it's a great movie.",
+ "This is the best movie of the year to date.",
+ "The movie is excellent."]
+ crops, crops_u, concept_bank_w \
+ = cockatiel_explainer_pos.fit(inputs=data,
+ class_id=1, alpha_w=0)
+ assert len(crops) == 6
+ assert crops_u.shape == (6, number_of_concepts)
+ assert concept_bank_w.shape == (number_of_concepts, nb_classes)
+
+ global_importance_pos = cockatiel_explainer_pos.estimate_importance()
+ assert len(global_importance_pos) == number_of_concepts
+
+
+class CustomRobertaForSequenceClassification(RobertaPreTrainedModel):
+ """
+ A custom RoBERTa model using a custom fully-connected head with a non-negative layer on which
+ we can compute the NMF.
+
+ Parameters
+ ----------
+ config
+ An object indicating the hidden layer size, the presence and amount of dropout for the
+ classification head.
+ """
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = 1
+ self.config = config
+
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None
+ else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, 2)
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+
+ self.mse_loss = MSELoss()
+
+ self.post_init()
+
+ def classifier_features(self, x):
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.relu(x)
+ return x
+
+ def classifier_end_model(self, x):
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+ def classifier(self, x):
+ x = self.classifier_features(x)
+ x = self.classifier_end_model(x)
+ return x
+
+ def features(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ # chain RobertaModel and the classifier features ending with a ReLU
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0][:, 0, :]
+ return self.classifier_features(sequence_output)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ activations = self.features(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = self.classifier_end_model(activations)
+ return logits
+
+
+def test_classifier():
+
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ pretrained_model_path = "wrmurray/roberta-base-finetuned-imdb"
+ tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_path)
+ model = CustomRobertaForSequenceClassification.from_pretrained(
+ pretrained_model_path).to(device)
+ model = model.eval()
+
+ imdb_preprocessor = ImdbPreprocessor(tokenizer,
+ device,
+ padding="max_length",
+ max_length=512,
+ truncation=True,
+ return_tensors='pt')
+
+ input_to_latent_model = model.features
+ latent_to_logit_model = model.classifier
+ number_of_concepts = 2
+ cockatiel_explainer_pos = Cockatiel(input_to_latent_model=input_to_latent_model,
+ latent_to_logit_model=latent_to_logit_model,
+ preprocessor=imdb_preprocessor,
+ number_of_concepts=number_of_concepts,
+ batch_size=64,
+ device=device)
+
+ data = ["Absolutely great riveting from beginning to end. It was great. Great great great.",
+ "But it's a great movie.",
+ "I liked this film.",
+            "I enjoyed the scenario of this movie."]  # sentences with 'great', others not
+ crops, crops_u, concept_bank_w \
+ = cockatiel_explainer_pos.fit(inputs=data,
+ class_id=1, alpha_w=0)
+ assert len(crops) == 6
+ assert crops_u.shape == (6, number_of_concepts)
+ assert concept_bank_w.shape[0] == number_of_concepts
+
+ global_importance_pos = cockatiel_explainer_pos.estimate_importance()
+ assert len(global_importance_pos) == number_of_concepts
+ most_important_concepts_ids = global_importance_pos.argsort()
+
+ nb_excerpts = 2
+ nb_most_important_concepts = 2
+
+ best_sentences_per_concept = cockatiel_explainer_pos.get_best_excerpts_per_concept(
+ nb_excerpts=nb_excerpts, nb_most_important_concepts=nb_most_important_concepts)
+ assert np.all(best_sentences_per_concept[most_important_concepts_ids[0]] == [
+ 'I enjoyed the scenario of this movie.', 'I liked this film.'])
+ assert np.all(best_sentences_per_concept[most_important_concepts_ids[1]] == [
+ 'Great great great.', 'It was great.'])
diff --git a/tests/nlp/__init__.py b/tests/nlp/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/nlp/test_preprocessor.py b/tests/nlp/test_preprocessor.py
new file mode 100644
index 00000000..9c888b68
--- /dev/null
+++ b/tests/nlp/test_preprocessor.py
@@ -0,0 +1,96 @@
+import numpy as np
+
+import torch
+import pytest
+from transformers import RobertaTokenizerFast
+from xplique.commons.torch_operations import NlpPreprocessor, batcher, nlp_batch_predict
+
+
+class ImdbPreprocessor(NlpPreprocessor):
+ def preprocess(self, inputs: np.ndarray, labels: np.ndarray):
+ preprocessed_inputs = self.tokenize(samples=inputs.tolist())
+ preprocessed_labels = torch.Tensor(np.array(
+ labels.tolist()) == 'positive').int().to(self.device)
+ return preprocessed_inputs, preprocessed_labels
+
+
+@pytest.fixture
+def imdb_preprocessor():
+ pretrained_model_path = "wrmurray/roberta-base-finetuned-imdb"
+ tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_path)
+ return ImdbPreprocessor(tokenizer,
+ device='cpu',
+ padding="max_length",
+ max_length=512,
+ truncation=True,
+ return_tensors='pt')
+
+
+def test_imdb_preprocessor(imdb_preprocessor):
+ inputs = np.array(['This is an example sentence.',
+ 'Another sentence for testing.'])
+ labels = np.array(['positive', 'negative'])
+ x_preprocessed, y_preprocessed = imdb_preprocessor.preprocess(
+ inputs, labels)
+ assert 'input_ids' in x_preprocessed
+ assert 'attention_mask' in x_preprocessed
+ assert x_preprocessed['input_ids'].shape == x_preprocessed['attention_mask'].shape == (
+ 2, 512)
+ assert y_preprocessed.shape == (2,)
+ assert y_preprocessed[0] == 1
+ assert y_preprocessed[1] == 0
+
+
+def test_batcher():
+ # Test case 1: elements perfectly divisible by batch_size
+ elements1 = [1, 2, 3, 4, 5, 6]
+ batch_size1 = 2
+ batches1 = list(batcher(elements1, batch_size1))
+ assert batches1 == [[1, 2], [3, 4], [5, 6]]
+
+ # Test case 2: elements not perfectly divisible by batch_size
+ elements2 = [1, 2, 3, 4, 5]
+ batch_size2 = 2
+ batches2 = list(batcher(elements2, batch_size2))
+ assert batches2 == [[1, 2], [3, 4], [5]]
+
+ # Test case 3: batch_size greater than the length of elements
+ elements3 = [1, 2, 3, 4, 5]
+ batch_size3 = 10
+ batches3 = list(batcher(elements3, batch_size3))
+ assert batches3 == [[1, 2, 3, 4, 5]]
+
+ # Test case 4: batch_size equal to the length of elements
+ elements4 = [1, 2, 3, 4, 5]
+ batch_size4 = 5
+ batches4 = list(batcher(elements4, batch_size4))
+ assert batches4 == [[1, 2, 3, 4, 5]]
+
+ # Test case 5: with a complex list of elements
+ elements5 = list(zip([1, 2, 3, 4, 5], ['a', 'b', 'c', 'd', 'e']))
+ batch_size5 = 2
+ batches5 = list(batcher(elements5, batch_size5))
+ assert batches5 == [[(1, 'a'), (2, 'b')], [(3, 'c'), (4, 'd')], [(5, 'e')]]
+
+
+class ImdbClassifier(torch.nn.Module):
+ def __init__(self, nb_classes):
+ super().__init__()
+ self.nb_classes = nb_classes
+
+ def forward(self, **kwargs):
+ return torch.randn(kwargs['input_ids'].shape[0], self.nb_classes)
+
+
+def test_batch_predict(imdb_preprocessor):
+ inputs = ['text1', 'text2', 'text3']
+ labels = ['positive', 'negative', 'positive']
+
+ batch_size = 2
+ nb_classes = 2
+ mock_model = ImdbClassifier(nb_classes)
+
+ predictions, processed_labels = nlp_batch_predict(
+ mock_model, imdb_preprocessor, inputs, labels, batch_size)
+
+ assert len(predictions) == len(processed_labels) == len(inputs)
diff --git a/tests/nlp/test_token_extractor.py b/tests/nlp/test_token_extractor.py
new file mode 100644
index 00000000..7907390f
--- /dev/null
+++ b/tests/nlp/test_token_extractor.py
@@ -0,0 +1,203 @@
+from xplique.commons.nlp import WordExtractor, SentenceExtractor
+from xplique.commons.nlp import FlairClauseExtractor, ExcerptExtractor, ExtractorFactory, SpacyClauseExtractor
+from flair.models import SequenceTagger
+
+import spacy
+
+import pytest
+
+
+@pytest.fixture
+def example_sentence():
+ return "One two three. Second sentence.Third Sentence, test1, test2; test3-test4 .GO!"\
+ " Trust me,. sentence not starting with capital letter."\
+ " Sentence with dots..Word, Word, word,...word ....so a sentence"
+
+
+@pytest.fixture
+def chunk_english_tagger():
+ return SequenceTagger.load("flair/chunk-english")
+
+
+@pytest.fixture
+def web_sm_pipeline():
+ spacy.cli.download("en_core_web_sm")
+ return spacy.load("en_core_web_sm")
+
+
+def test_word_extractor(example_sentence):
+ extractor = WordExtractor()
+ tokens, separator = extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ assert separator == ' '
+ expected_tokens = ['One', 'two', 'three', '.',
+ 'Second', 'sentence.Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
+ 'test3-test4', '.GO', '!', 'Trust', 'me', ',', '.', 'sentence', 'not',
+ 'starting', 'with', 'capital', 'letter', '.', 'Sentence', 'with', 'dots',
+ '..', 'Word', ',', 'Word', ',', 'word', ',', '...', 'word', '....', 'so',
+ 'a', 'sentence']
+ assert tokens == expected_tokens, print('tokens:', tokens)
+
+
+def test_word_extractor_ignore_words(example_sentence):
+ extractor = WordExtractor()
+ tokens, separator = extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ assert separator == ' '
+ expected_tokens = ['One', 'two', 'three', '.',
+ 'Second', 'sentence.Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
+ 'test3-test4', '.GO', '!', 'Trust', 'me', ',', '.', 'sentence', 'not',
+ 'starting', 'with', 'capital', 'letter', '.', 'Sentence', 'with', 'dots',
+ '..', 'Word', ',', 'Word', ',', 'word', ',', '...', 'word', '....',
+ 'so', 'a', 'sentence']
+ assert tokens == expected_tokens, print('tokens:', tokens)
+
+
+def test_word_extractor_from_list(example_sentence):
+ extractor = WordExtractor()
+ tokens, separator = extractor.extract_tokens(
+ [example_sentence, example_sentence])
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ assert separator == ' '
+
+
+def test_sentence_extractor(example_sentence):
+ extractor = SentenceExtractor()
+ tokens, separator = extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ assert separator == '. '
+ expected_tokens = ['One two three.',
+ 'Second sentence.Third Sentence, test1, test2; test3-test4 .GO!',
+ 'Trust me,.',
+ 'sentence not starting with capital letter.',
+ 'Sentence with dots..Word, Word, word,...word ....so a sentence']
+ assert tokens == expected_tokens, print('tokens:', tokens)
+
+
+def test_excerpt_extractor(example_sentence):
+ extractor = ExcerptExtractor()
+ tokens, separator = extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ assert separator == ' '
+ expected_tokens = ['One two three.',
+ 'Second sentence.',
+ 'Third Sentence, test1, test2; test3-test4 .',
+ 'GO!',
+ 'Trust me,.',
+ 'Sentence with dots.',
+ 'Word, Word, word,.']
+ assert tokens == expected_tokens, print('tokens:', tokens)
+
+
+def test_flair_clause_extractor_close_type_none(example_sentence, chunk_english_tagger):
+ clause_extractor = FlairClauseExtractor(
+ tagger=chunk_english_tagger, clause_type=None)
+ tokens, separator = clause_extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ expected_tokens = ['One two three',
+ 'Second sentence.Third Sentence',
+ 'test1',
+ 'test2',
+ 'test3-test4',
+ 'GO',
+ 'Trust',
+ 'me',
+ 'sentence',
+ 'not starting',
+ 'with',
+ 'capital letter',
+ 'Sentence',
+ 'with',
+ 'dots.',
+ 'Word',
+ 'Word, word,...word',
+ 'a sentence']
+ assert tokens == expected_tokens
+
+
+def test_flair_clause_extractor_close_type_NP(example_sentence, chunk_english_tagger):
+ clause_extractor = FlairClauseExtractor(
+ tagger=chunk_english_tagger, clause_type=['NP'])
+ tokens, separator = clause_extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ expected_tokens = ['One two three',
+ 'Second sentence.Third Sentence',
+ 'test1',
+ 'test2',
+ 'test3-test4',
+ 'me',
+ 'sentence',
+ 'capital letter',
+ 'Sentence',
+ 'dots.',
+ 'Word',
+ 'Word, word,...word',
+ 'a sentence']
+ assert tokens == expected_tokens
+
+
+def test_clause_extractor_close_type_ADJP(example_sentence, chunk_english_tagger):
+ clause_extractor = FlairClauseExtractor(
+ tagger=chunk_english_tagger, clause_type=['ADJP'])
+ tokens, separator = clause_extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ print(tokens)
+ expected_tokens = []
+ assert tokens == expected_tokens
+
+
+def test_spacy_clause_extractor_close_type_none(example_sentence, web_sm_pipeline):
+ clause_extractor = SpacyClauseExtractor(
+ pipeline=web_sm_pipeline, clause_type=None)
+ tokens, separator = clause_extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ expected_tokens = ['One', 'two', 'three', '.', 'Second', 'sentence', '.',
+ 'Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
+ 'test3', '-', 'test4', '.GO', '!', 'Trust', 'me', ',',
+ '.', 'sentence', 'not', 'starting', 'with', 'capital',
+ 'letter', '.', 'Sentence', 'with', 'dots', '..', 'Word',
+ ',', 'Word', ',', 'word,', '...', 'word', '....', 'so', 'a', 'sentence']
+ assert tokens == expected_tokens
+
+
+def test_spacy_clause_extractor_close_type_NN(example_sentence, web_sm_pipeline):
+ clause_extractor = SpacyClauseExtractor(
+ pipeline=web_sm_pipeline, clause_type=['NN'])
+ tokens, separator = clause_extractor.extract_tokens(example_sentence)
+ assert isinstance(tokens, list)
+ assert isinstance(separator, str)
+ expected_tokens = ['sentence', 'test1', 'test3', 'test4', 'sentence',
+ 'capital', 'letter', 'Sentence', '...', 'word', 'sentence']
+
+ assert tokens == expected_tokens
+
+
+def test_extractor_factory():
+ word_extractor = ExtractorFactory.get_extractor(extract_fct="word")
+ assert isinstance(word_extractor, WordExtractor)
+
+ sentence_extractor = ExtractorFactory.get_extractor(extract_fct="sentence")
+ assert isinstance(sentence_extractor, SentenceExtractor)
+
+ excerpt_extractor = ExtractorFactory.get_extractor(extract_fct="excerpt")
+ assert isinstance(excerpt_extractor, ExcerptExtractor)
+
+ flair_clause_extractor = ExtractorFactory.get_extractor(
+ extract_fct="flair_clause", clause_type=['NP'], tagger=None)
+ assert isinstance(flair_clause_extractor, FlairClauseExtractor)
+
+ spacy_clause_extractor = ExtractorFactory.get_extractor(
+ extract_fct="spacy_clause", clause_type=['NP'], pipeline=None)
+ assert isinstance(spacy_clause_extractor, SpacyClauseExtractor)
+
+ with pytest.raises(ValueError):
+ ExtractorFactory.get_extractor(extract_fct="invalid")
diff --git a/xplique/attributions/__init__.py b/xplique/attributions/__init__.py
index 2b5169f3..ae6581de 100644
--- a/xplique/attributions/__init__.py
+++ b/xplique/attributions/__init__.py
@@ -16,4 +16,5 @@
from .object_detector import BoundingBoxesExplainer
from .global_sensitivity_analysis import SobolAttributionMethod, HsicAttributionMethod
from .gradient_statistics import SmoothGrad, VarGrad, SquareGrad
+from .nlp_occlusion import NlpOcclusion
from . import global_sensitivity_analysis
diff --git a/xplique/attributions/grad_cam_pp.py b/xplique/attributions/grad_cam_pp.py
index 485c0e3b..15ea3d07 100644
--- a/xplique/attributions/grad_cam_pp.py
+++ b/xplique/attributions/grad_cam_pp.py
@@ -40,10 +40,6 @@ class GradCAMPP(GradCAM):
If a string is provided it will look for the layer name.
"""
- # Avoid zero division during procedure. (the value is not important, as if the denominator is
- # zero, then the nominator will also be zero).
- EPSILON = tf.constant(1e-4)
-
@staticmethod
@tf.function
def _compute_weights(feature_maps_gradients: tf.Tensor,
@@ -71,7 +67,11 @@ def _compute_weights(feature_maps_gradients: tf.Tensor,
nominator = feature_maps_gradients_square
denominator = 2.0 * feature_maps_gradients_square + \
feature_maps_gradients_cube * feature_map_avg
- denominator += tf.cast(denominator == 0, tf.float32) * GradCAMPP.EPSILON
+
+ # Avoid zero division during procedure. (the value is not important, as if the denominator is
+ # zero, then the nominator will also be zero).
+ EPSILON = tf.constant(1e-4)
+ denominator += tf.cast(denominator == 0, tf.float32) * EPSILON
feature_map_alphas = nominator / denominator * tf.nn.relu(feature_maps_gradients)
weights = tf.reduce_mean(feature_map_alphas, axis=(1, 2))
diff --git a/xplique/attributions/nlp_occlusion.py b/xplique/attributions/nlp_occlusion.py
new file mode 100644
index 00000000..29261459
--- /dev/null
+++ b/xplique/attributions/nlp_occlusion.py
@@ -0,0 +1,106 @@
+"""
+Module related to Occlusion sensitivity method for NLP.
+"""
+
+import numpy as np
+
+from .base import BlackBoxExplainer
+from ..commons import Tasks
+from ..types import Callable, Union, Optional, OperatorSignature, List
+
+class NlpOcclusion(BlackBoxExplainer):
+ """
+ Occlusion class for NLP.
+ """
+ def __init__(self,
+ model: Callable,
+ batch_size: Optional[int] = 32,
+ operator: Optional[Union[Tasks, str, OperatorSignature]] = None):
+ super().__init__(model, batch_size, operator)
+
+ @staticmethod
+ def _get_masks(input_len: int) -> np.ndarray:
+ """
+ Generate occlusion masks for a given input length.
+
+ Parameters
+ ----------
+ input_len : int
+ The length of the input for which occlusion masks are generated.
+ Typically it will be the number of words of a sentence.
+
+ Returns
+ -------
+ occlusion_masks : np.ndarray
+ The boolean occlusion masks, an identity matrix with False for the main diagonal.
+ This kind of mask can be used to generate n sentences,
+ each with a single distinct word removed.
+ """
+ return np.eye(input_len) == 0
+
+ @staticmethod
+ def _apply_masks(words: List[str], masks: np.ndarray) -> np.ndarray:
+ """
+ Apply occlusion masks to a list of words.
+
+ Parameters
+ ----------
+ words : List[str]
+ The list of words to which occlusion masks are applied.
+ masks : np.ndarray
+ The boolean occlusion masks to be applied.
+
+ Returns
+ -------
+ occluded_words : np.ndarray
+ The list of words with occlusion masks applied.
+ """
+ perturbated_words = [np.array(words)[mask].tolist() for mask in masks]
+ return perturbated_words
+
+ def explain(self,
+ sentence: str,
+ words: List[str],
+ separator: str) -> np.ndarray:
+ """
+ Generate an explanation for the input sentence, by providing the importance of each word.
+ The importance will be computed by successively occluding each word of the sentence and
+ studying the impact of this occlusion on the model results.
+
+ Parameters
+ ----------
+ sentence : str
+ The input sentence for which an explanation is generated.
+ words : List[str]
+ List of words used to generate the explanation. These words must be part of
+ the input sentence, the importance will be computed on this list of words
+            (i.e some words of the original sentence can be omitted this way).
+ separator : str
+ The separator used to join the words after the occlusion step, so a full
+ sentence can be fed to the model.
+
+ Returns
+ -------
+ explanation : np.ndarray
+ The generated explanation of format (nb_concepts, nb_words).
+ """
+
+ # generate n sentences with a different word masked (removed) each time
+ masks = NlpOcclusion._get_masks(len(words))
+ perturbated_words = NlpOcclusion._apply_masks(words, masks)
+
+ perturbated_sentences = [sentence]
+ perturbated_sentences.extend(
+ [separator.join(perturbated_word) for perturbated_word in perturbated_words])
+
+ # transform the perturbated reviews into their concept representation
+ # u_values has shape: ((W+1) x C)
+ u_values = self.model(perturbated_sentences)
+
+ # Compute sensitivities: importances = u_value of the whole sentence - u_value of each word
+ whole_sentence_uvalues = u_values[0,:]
+ words_uvalues = u_values[1:,:]
+ l_importances = (whole_sentence_uvalues - words_uvalues).transpose()
+ l_importances /= (np.max(np.abs(l_importances)) + 1e-5)
+
+ return l_importances
diff --git a/xplique/attributions/rise.py b/xplique/attributions/rise.py
index 64afaa8a..424bd05b 100644
--- a/xplique/attributions/rise.py
+++ b/xplique/attributions/rise.py
@@ -42,10 +42,6 @@ class Rise(BlackBoxExplainer):
Value used as when applying masks.
"""
- # Avoid zero division during procedure. (the value is not important, as if the denominator is
- # zero, then the nominator will also be zero).
- EPSILON = tf.constant(1e-4)
-
def __init__(self,
model: Callable,
batch_size: Optional[int] = 32,
@@ -113,7 +109,10 @@ def explain(self,
rise_nominator += tf.reduce_sum(predictions * masks_upsampled, 0)
rise_denominator += tf.reduce_sum(masks_upsampled, 0)
- rise_map = rise_nominator / (rise_denominator + Rise.EPSILON)
+ # Avoid zero division during procedure. (the value is not important, as if the denominator is
+ # zero, then the nominator will also be zero).
+ EPSILON = tf.constant(1e-4)
+ rise_map = rise_nominator / (rise_denominator + EPSILON)
rise_map = rise_map[tf.newaxis]
rise_maps = rise_map if rise_maps is None else tf.concat([rise_maps, rise_map], axis=0)
diff --git a/xplique/commons/__init__.py b/xplique/commons/__init__.py
index 94237f90..cea752a4 100644
--- a/xplique/commons/__init__.py
+++ b/xplique/commons/__init__.py
@@ -4,10 +4,16 @@
from .data_conversion import tensor_sanitize, numpy_sanitize
from .model_override import guided_relu_policy, deconv_relu_policy, override_relu_gradient, \
- find_layer, open_relu_policy
+ find_layer, open_relu_policy
from .tf_operations import repeat_labels, batch_tensor
from .callable_operations import predictions_one_hot_callable
from .operators_operations import (Tasks, get_operator, check_operator, operator_batching,
get_inference_function, get_gradient_functions)
from .exceptions import no_gradients_available, raise_invalid_operator
from .forgrad import forgrad
+try:
+    from .nlp import (TokenExtractor, WordExtractor, SentenceExtractor, FlairClauseExtractor,
+                      SpacyClauseExtractor, ExcerptExtractor, ExtractorFactory)
+ from .torch_operations import batcher, nlp_batch_predict, NlpPreprocessor
+except ModuleNotFoundError:
+ pass
diff --git a/xplique/commons/nlp.py b/xplique/commons/nlp.py
new file mode 100644
index 00000000..9d21e56e
--- /dev/null
+++ b/xplique/commons/nlp.py
@@ -0,0 +1,288 @@
+"""
+NLP Util Classes
+"""
+
+from abc import ABC, abstractmethod
+import re
+from typing import Type
+
+from nltk.tokenize import word_tokenize, sent_tokenize
+
+from flair.models import SequenceTagger
+from flair.data import Sentence
+
+from spacy import Language
+
+from ..types import Union, List, Tuple
+
+
+class TokenExtractor(ABC):
+ """
+ Base class for the token extractors.
+
+ Parameters
+ ----------
+ language
+ The language of the sentence.
+ """
+
+ def __init__(self, language: str = "english"):
+ self.separator = " "
+ self.language = language
+
+ @abstractmethod
+    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
+ """
+ Extract tokens from a sentence.
+
+ Parameters
+ ----------
+ sentence
+ The input sentence, or a list of input sentences.
+
+ Returns
+ -------
+ tokens
+ A list of extracted tokens.
+ """
+
+
+class WordExtractor(TokenExtractor):
+ """
+ Uses NLTK word tokenizer.
+ """
+
+    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
+ """
+ Extract tokens from a sentence, using the NLTK word tokenizer
+
+ Parameters
+ ----------
+ sentence
+ The input sentence, or a list of input sentences.
+
+ Returns
+ -------
+ tokens
+ A list of extracted words.
+ """
+ if not isinstance(sentence, str):
+ sentence = '.'.join(sentence)
+ words = word_tokenize(sentence)
+
+ return words, self.separator
+
+
+class SentenceExtractor(TokenExtractor):
+ """
+ Uses NLTK sentence tokenizer.
+
+ Parameters
+ ----------
+ language
+ The language of the sentence.
+ """
+
+ def __init__(self, language: str = "english"):
+ super().__init__(language=language)
+ self.separator = ". "
+
+    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
+ """
+ Extract tokens from a sentence, using the NLTK sentence tokenizer.
+
+ Parameters
+ ----------
+ sentence
+ The input sentence, or a list of input sentences.
+
+ Returns
+ -------
+ tokens
+ A list of extracted sentences.
+ """
+ if not isinstance(sentence, str):
+ sentence = '.'.join(sentence)
+ words = sent_tokenize(sentence, self.language)
+ return words, self.separator
+
+
+class ExcerptExtractor(TokenExtractor):
+ """
+ Uses a custom excerpt tokenizer.
+ """
+
+    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
+ """
+ Extract excerpts from a sentence:
+ - split the input string with '.', '?', and '!' separators
+ - ignore excerpts not starting with a capital letter
+
+ Parameters
+ ----------
+ sentence
+ The input sentence, or a list of input sentences.
+
+ Returns
+ -------
+ tokens
+ A list of extracted excerpts.
+ """
+ if not isinstance(sentence, str):
+ sentence = '.'.join(sentence)
+ # Split the input text with '.', '?' and '!' separators.
+ regexp = r"[A-Z][^.?!]*[.?!]"
+ res = re.findall(regexp, sentence)
+ excerpt_dataset = [st.strip() for st in res]
+ return excerpt_dataset, self.separator
+
+
+class FlairClauseExtractor(TokenExtractor):
+ """
+ Uses Flair sentence tokenizer.
+
+ Parameters
+ ----------
+ tagger
+ A trained Flair SequenceTagger.
+ clause_type
+ A list with the types of clauses to keep. Each clause shall be a string
+        corresponding to one of Flair SequenceTagger dictionary.
+ Example: clause_type = ['NP', 'ADJP']
+ If None, all clauses are kept.
+ language
+ The language of the sentence.
+ """
+
+ def __init__(self,
+ tagger: Type[SequenceTagger],
+ clause_type: List[str] = None,
+ language: str = "english"):
+ super().__init__(language=language)
+ self.clause_type = clause_type
+ self.tagger = tagger
+
+ def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
+ """
+ Separates the input texts into clauses, and only keeps the ones belonging to
+ the specified types.
+ If clause_type is None, the texts are split but all the clauses are kept.
+
+ Parameters
+ ----------
+ sentence
+ A list of strings that we wish to separate into clauses.
+
+ Returns
+ -------
+ clause_list
+ A list with input texts split into clauses.
+ separator
+            The separator used for splitting the sentence.
+ """
+ if not isinstance(sentence, str):
+ sentence = '.'.join(sentence)
+ sentence = Sentence(sentence)
+ self.tagger.predict(sentences=sentence)
+ clause_list = []
+ for segment in sentence.get_labels():
+ if self.clause_type is None:
+ clause_list.append(segment.data_point.text)
+ elif segment.value in self.clause_type:
+ clause_list.append(segment.data_point.text)
+
+ return clause_list, self.separator
+
+
+class SpacyClauseExtractor(TokenExtractor):
+ """
+ Uses Spacy sentence tokenizer.
+
+ Parameters
+ ----------
+ pipeline
+ A trained Spacy pipeline.
+ clause_type
+ A list with the types of clauses to keep. Each clause shall be a string
+        corresponding to one of Spacy Tag dictionary.
+ Example: clause_type = ['NNP', 'VBZ']
+ If None, all clauses are kept.
+ """
+
+ def __init__(self,
+ pipeline: Type[Language],
+ clause_type: List[str] = None):
+ super().__init__(language=None)
+ self.clause_type = clause_type
+ self.pipeline = pipeline
+
+ def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
+ """
+ Separates the input texts into clauses, and only keeps the ones belonging to
+ the specified types.
+ If clause_type is None, the texts are split but all the clauses are kept.
+
+ Parameters
+ ----------
+ sentence
+ A list of strings that we wish to separate into clauses.
+
+ Returns
+ -------
+ clause_list
+ A list with input texts split into clauses.
+ separator
+            The separator used for splitting the sentence.
+ """
+ if not isinstance(sentence, str):
+ sentence = '.'.join(sentence)
+
+ clause_list = []
+ for token in self.pipeline(sentence):
+ if self.clause_type is None:
+ clause_list.append(token.text)
+ elif token.tag_ in self.clause_type:
+ clause_list.append(token.text)
+
+ return clause_list, self.separator
+
+
+class ExtractorFactory():
+ """
+ Factory for extractor classes.
+ """
+ @staticmethod
+ def get_extractor(extract_fct="sentence", tagger=None, pipeline=None, clause_type=None):
+ """
+ Get an instance of an extractor based on the specified extraction function.
+
+ Parameters
+ ----------
+ extract_fct
+ The type of extraction function to use, either 'word', 'sentence', 'excerpt',
+ 'flair_clause', 'spacy_clause'.
+ Default is "sentence".
+ tagger
+ A Flair SequenceTagger to use for Flair clause extractor.
+ pipeline
+ A Spacy Pipeline to use for Spacy clause extractor.
+ clause_type
+ Additional parameter specifying the type of clause (default is None).
+
+ Returns
+ -------
+ Extractor
+ An instance of the selected extractor based on the specified function.
+ """
+ if extract_fct == "flair_clause":
+ return FlairClauseExtractor(tagger, clause_type)
+ if extract_fct == "spacy_clause":
+ return SpacyClauseExtractor(pipeline, clause_type)
+ if extract_fct == "sentence":
+ return SentenceExtractor()
+ if extract_fct == "word":
+ return WordExtractor()
+ if extract_fct == "excerpt":
+ return ExcerptExtractor()
+        raise ValueError("Extraction function can be only 'flair_clause', 'spacy_clause', \
+                        'sentence', 'word' or 'excerpt'")
diff --git a/xplique/commons/operators.py b/xplique/commons/operators.py
index 84217df9..0c10fcd0 100644
--- a/xplique/commons/operators.py
+++ b/xplique/commons/operators.py
@@ -7,7 +7,7 @@
import tensorflow as tf
from ..types import Callable, Optional
-from ..utils_functions.object_detection import _box_iou, _format_objects, _EPSILON
+from ..utils_functions.object_detection import _box_iou, _format_objects
@tf.function
@@ -197,6 +197,7 @@ def batch_loop(args):
class_refs = tf.repeat(tf.expand_dims(class_refs, axis=1), repeats=size, axis=1)
# (nb_box_ref, nb_box_pred)
+ _EPSILON = tf.constant(1e-4)
classification_score = tf.reduce_sum(class_refs * classification, axis=-1) \
/ (tf.norm(classification, axis=-1) * tf.norm(class_refs, axis=-1)+ _EPSILON)
diff --git a/xplique/commons/torch_operations.py b/xplique/commons/torch_operations.py
new file mode 100644
index 00000000..2dd0c431
--- /dev/null
+++ b/xplique/commons/torch_operations.py
@@ -0,0 +1,150 @@
+"""
+Custom pytorch operations
+"""
+
+from abc import ABC, abstractmethod
+from math import ceil
+import numpy as np
+import torch
+
+from ..types import Union, Tuple, Callable, List, Dict
+
+
+class NlpPreprocessor(ABC):
+ """
+ Abstract base class for NLP preprocessing.
+
+ Parameters
+ ----------
+ tokenizer
+ The tokenizer function to be used for tokenization. It must be
+ a callable that returns outputs matching the model inputs expectations.
+ device
+ The device on which to perform tokenization (default is 'cuda').
+ kwargs
+ A list of key value arguments that will be passed to the tokenizer when
+ tokenizing inputs.
+ """
+
+ def __init__(self, tokenizer: Callable, device='cuda', **kwargs):
+ self.tokenizer = tokenizer
+ self.device = device
+ self.tokenizer_kwargs = kwargs
+
+ def tokenize(self,
+ samples: Union[List[str], np.ndarray]) -> Dict[str, torch.Tensor]:
+ """
+ Transforms a list of strings into tokens for consumption by the transformer model.
+
+ Parameters
+ ----------
+ samples
+ A list of input strings to be tokenized.
+
+ Returns
+ -------
+ tokenized
+ Result of the tokenizer operation over the input sequences. Usually a dictionary
+ with keys 'input_ids' and 'attention_mask'.
+ """
+ tokenized = self.tokenizer(
+ list(samples),
+ **self.tokenizer_kwargs
+ ).to(self.device)
+
+ return tokenized
+
+ @abstractmethod
+ def preprocess(self,
+ inputs: List[str],
+ labels: List[str]) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Pre-process the dataset and adapt it according to a specific task.
+
+ Parameters
+ ----------
+ inputs
+ The input text data.
+ labels
+ The corresponding labels.
+
+        Returns
+        -------
+ preprocessed_inputs
+ The pre-processed inputs.
+ preprocessed_labels
+ The pre-processed outputs.
+ """
+ raise NotImplementedError
+
+
+def batcher(elements, batch_size: int):
+ """
+    A generator function to create batches from a list of elements.
+
+ Parameters
+ ----------
+ elements
+ The list of elements to be batched.
+ batch_size
+ The size of each batch.
+
+    Yields
+    ------
+    batch
+ A batch of elements (yielded).
+ """
+ nb_batchs = ceil(len(elements) / batch_size)
+
+ for batch_i in range(nb_batchs):
+ batch_start = batch_i * batch_size
+ batch_end = batch_start + batch_size
+
+ batch = elements[batch_start:batch_end]
+ yield batch
+
+
+def nlp_batch_predict(model: Union[torch.nn.Module, Callable],
+ preprocessor: NlpPreprocessor,
+ inputs: List[str],
+ labels: List[str],
+ batch_size: int = 64) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Pre-processes and predicts using the transformer model in batches.
+
+ Parameters
+ ----------
+ model
+ The transformer model for prediction.
+ preprocessor
+ An instance of NlpPreprocessor used for preprocessing the input texts.
+ inputs
+ A list of n_samples input texts to be predicted.
+ labels
+ A list of labels corresponding to the input texts.
+ batch_size
+ The batch size (default is 64).
+
+ Returns
+ -------
+ predictions
+ The predictions output of the model for the pre-processed input texts.
+ processed_labels
+ The pre-processed labels.
+ """
+ predictions = None
+ processed_labels = None
+
+ with torch.no_grad():
+ dataset = list(zip(inputs, labels))
+ for batch in batcher(dataset, batch_size):
+ batch_inputs, batch_labels = zip(*batch)
+ x_preprocessed, y_preprocessed = preprocessor.preprocess(
+ np.array(batch_inputs), np.array(batch_labels))
+ out_batch = model(**x_preprocessed)
+ predictions = out_batch if predictions is None else torch.cat(
+ [predictions, out_batch])
+ processed_labels = y_preprocessed if processed_labels is None else torch.cat([
+ processed_labels, y_preprocessed])
+
+ return predictions, processed_labels
diff --git a/xplique/concepts/__init__.py b/xplique/concepts/__init__.py
index 67a100a6..46d7b23a 100644
--- a/xplique/concepts/__init__.py
+++ b/xplique/concepts/__init__.py
@@ -8,5 +8,7 @@
from .craft_tf import CraftTf, CraftManagerTf
try:
from .craft_torch import CraftTorch, CraftManagerTorch
+ from .cockatiel import CockatielTorch
+ from .cockatiel_manager import CockatielManagerTorch
except ImportError:
pass
diff --git a/xplique/concepts/cockatiel.py b/xplique/concepts/cockatiel.py
new file mode 100644
index 00000000..598ab6fe
--- /dev/null
+++ b/xplique/concepts/cockatiel.py
@@ -0,0 +1,569 @@
+"""
+COCKATIEL Module for Pytorch
+"""
+
+from typing import Callable, List, Tuple, Optional, Union, Type
+from math import ceil
+
+import html
+from IPython.display import display, HTML
+
+import numpy as np
+import torch
+
+from xplique.attributions.base import BlackBoxExplainer
+from ..commons import TokenExtractor, WordExtractor, ExcerptExtractor, NlpPreprocessor
+from .craft import MaskSampler
+from .craft_torch import BaseCraftTorch
+
+
+class BaseCockatiel(BaseCraftTorch):
+ """
+ A class implementing COCKATIEL, the concept based explainability method for NLP introduced
+ in https://arxiv.org/abs/2305.06754
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must be a Pytorch model accepting tokenized inputs.
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ Must be a Pytorch model.
+ preprocessor
+ A callable object to transform strings into inputs for the model.
+ Embeds a tokenizer that will be used to feed the inputs to the model.
+ number_of_concepts
+ The number of concepts to extract. Default is 25.
+ batch_size
+ The batch size for all the operations that use the model. Default is 256.
+ device
+ The type of device on which to place the torch tensors
+ """
+
+ def __init__(
+ self,
+ input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ preprocessor: NlpPreprocessor,
+ number_of_concepts: int = 25,
+ batch_size: int = 256,
+ device: str = 'cuda'
+ ):
+ super().__init__(input_to_latent_model, latent_to_logit_model,
+ number_of_concepts, batch_size, device=device)
+ self.preprocessor = preprocessor
+ self.patch_extractor = ExcerptExtractor()
+
+ def _latent_predict(self, inputs: List[str], resize=None) -> torch.Tensor:
+ """
+ Compute the embedding space using the 1st model `input_to_latent_model`.
+
+ Parameters
+ ----------
+ inputs
+ A list of input string data.
+ resize
+ Unused parameter.
+
+ Returns
+ -------
+ activations
+ The latent activations of shape (n_samples, channels)
+ """
+ def gen_activations():
+ with torch.no_grad():
+ nb_batchs = ceil(len(inputs) / self.batch_size)
+ for batch_id in range(nb_batchs):
+ batch_start = batch_id * self.batch_size
+ batch_end = batch_start + self.batch_size
+ batch_tokenized = self.preprocessor.tokenize(
+ samples=inputs[batch_start:batch_end])
+
+ batch_activations = self.input_to_latent_model(
+ **batch_tokenized)
+ yield batch_activations
+
+ activations = torch.cat(list(gen_activations()), 0)
+ return activations
+
+ def _preprocess(self, activations: torch.Tensor) -> np.ndarray:
+ """
+ Preprocesses the activations to make sure that they're the right shape
+ for being input to the NMF algorithm later.
+
+ Parameters
+ ----------
+ activations
+ The (non-negative) activations from the model under study.
+
+ Returns
+ -------
+ activations
+ The preprocessed activations, ready for COCKATIEL.
+ """
+
+        # pylint: disable=no-member
+ assert torch.min(activations) >= 0.0, "Activations must be positive."
+
+ if len(activations.shape) == 4:
+ activations = torch.mean(activations, dim=(1, 2))
+
+ return self._to_np_array(activations)
+
+ def _extract_patches(self, inputs: List[str]) -> Tuple[List[str], np.ndarray]:
+ """
+ Extract patches (excerpts) from the input sentences, and compute their embeddings.
+
+ Parameters
+ ----------
+ inputs
+ Input sentences.
+
+ Returns
+ -------
+ patches
+ A list of excerpts (n_patches).
+ activations
+ The excerpts activations (n_patches, channels).
+ """
+ crops, _ = self.patch_extractor.extract_tokens(inputs)
+
+ activations = self._latent_predict(crops)
+
+ # applying GAP(.) on the activation and ensure positivity if needed
+ activations = self._preprocess(activations)
+
+ return crops, activations
+
+ def estimate_importance(self,
+ inputs: np.ndarray = None,
+ sampler: MaskSampler = MaskSampler.SOBOL,
+ nb_design: int = 32,
+ cmaps: Optional[Union[Tuple, str]] = None) -> np.ndarray:
+ """
+ Estimates the importance of each concept for a given class, either globally
+ on the whole dataset provided in the fit() method (in this case, inputs shall
+ be set to None), or locally on a specific input sentence.
+
+ Parameters
+ ----------
+ inputs : numpy array or Tensor
+ The input data on which to compute the importances.
+ If None, then the inputs provided in the fit() method
+ will be used (global importance of the whole dataset).
+ Default is None.
+ sampler
+ The sampling method to use for masking. Defaults to MaskSampler.SOBOL.
+ nb_design
+ The number of design to use for the importance estimation. Default is 32.
+ cmaps
+ The list of colors associated with each concept.
+ Can be either:
+ - A list of (r, g, b) colors to use as a base for the colormap.
+ - A colormap name compatible with `plt.get_cmap(cmap)`.
+
+ Returns
+ -------
+ importances
+ The Sobol total index (importance score) for each concept.
+
+ """
+ return super().estimate_importance(inputs=inputs, sampler=sampler,
+ nb_design=nb_design, cmaps=cmaps)
+
+ def get_best_excerpts_per_concept(self,
+ nb_excerpts: int = 10,
+ nb_most_important_concepts: int = None) -> List[str]:
+ """
+ Return the best excerpts for each concept.
+
+ Parameters
+ ----------
+ nb_excerpts
+ The number of excerpts (patches) to fetch per concept. Defaults to 10.
+ nb_most_important_concepts
+ The number of concepts to display. If provided, only display
+ nb_most_important_concepts, otherwise display them all.
+ Default is None.
+
+ Returns
+ -------
+ best_excerpts_per_concept
+ The list of the best excerpts per concept
+ """
+ best_excerpts_per_concept = []
+ for _, _, best_crops in \
+ self._gen_best_concepts_crops(nb_excerpts, nb_most_important_concepts):
+ best_excerpts_per_concept.append(best_crops)
+ return best_excerpts_per_concept
+
+
+class CockatielNlpVisualizationMixin():
+ """
+ Class containing text visualization methods for Cockatiel.
+ """
+
+ def display_concepts_excerpts(self,
+ nb_patchs: int = 10,
+ nb_most_important_concepts: int = None) -> None:
+ """
+ Display the best excerpts for each concept.
+
+ Parameters
+ ----------
+ nb_patchs
+ The number of patches to display per concept. Defaults to 10.
+ nb_most_important_concepts
+ The number of concepts to display. If provided, only display
+ nb_most_important_concepts, otherwise display them all.
+ Default is None.
+ """
+ for c_id, c_id_importance, best_crops in \
+ self._gen_best_concepts_crops(nb_patchs, nb_most_important_concepts):
+
+ print(f"Concept {c_id} has an importance value of "
+ f"{c_id_importance:.2f}")
+ for crop in best_crops:
+ print(f"\t{crop}")
+
+ def plot_concept_attribution_maps(self,
+ sentences: List[str],
+ token_extractor: TokenExtractor,
+ explainer_class: Type[BlackBoxExplainer],
+ ignore_words: List[str] = None,
+ importances: np.ndarray = None,
+ nb_most_important_concepts: int = 5,
+ title: str = "",
+ css_namespace: str = "",
+ filter_percentile: int = 80,
+ display_width: int = 400) -> List[str]:
+ """
+ Display the concepts attribution maps for the sentences given in argument.
+
+ Parameters
+ ----------
+ sentences
+ The list of sentences for which attribution maps will be displayed.
+ token_extractor
+ The token extractor used to extract tokens from the sentences.
+ explainer_class
+ Explainer class used during the masking process. Typically a NlpOcclusion class.
+ ignore_words
+ Words to ignore during the occlusion process. These will not be part of the
+ extracted concepts.
+ importances
+ The importances computed by the estimate_importance() method.
+ If None is provided, then the global importances will be used, otherwise
+ the local importances set in this parameter will be used.
+ nb_most_important_concepts
+ Number of most important concepts to consider. Default is 5.
+ title
+ The title to use when displaying the results.
+ css_namespace
+ A namespace to be prefixed to the css class in order to avoid collisions.
+            Useful when running several instances of Cockatiel, to prevent having the
+ same color for all the concepts of the same name.
+ filter_percentile
+ Percentile used to filter the concept heatmap when displaying the HTML
+ sentences (only show concept if excess N-th percentile). Default is 80.
+ display_width
+ Width of the displayed HTML text. Default is 400.
+
+ Returns
+ -------
+ html_lines
+ List of HTML text lines.
+ """
+ if importances is None:
+ # global
+ most_important_concepts = self.sensitivity.most_important_concepts
+ else:
+ # local
+ most_important_concepts = np.argsort(importances)[::-1]
+
+ most_important_concepts = most_important_concepts[:nb_most_important_concepts]
+
+ # Build the CSS
+ concepts_css_names = [
+ f'{css_namespace}_{concept_id}' for concept_id in most_important_concepts]
+
+ css = CockatielNlpVisualizationMixin.build_concepts_css(
+ concepts_names=concepts_css_names,
+ cmaps=self.sensitivity.cmaps[:nb_most_important_concepts])
+ html_output = [css]
+
+ # Build the title
+ html_output.append(f"{title}")
+
+ # Build the legend
+        # NOTE(review): HTML tags reconstructed — markup was stripped from the source
+        html_output.append(
+            f"<p>Legend :<br>"
+            f"{CockatielNlpVisualizationMixin.get_legend(concepts_css_names)}</p>")
+
+ # Generate the HTML for each sentence
+        html_output.append('<ul>')  # list start
+        for sentence in sentences:
+            html_output.append('<li>')
+            html_output.append(
+                self.sentence_attribution_map(sentence=str(sentence),
+                                              token_extractor=token_extractor,
+                                              ignore_words=ignore_words,
+                                              explainer_class=explainer_class,
+                                              most_important_concepts=most_important_concepts,
+                                              concepts_names=concepts_css_names,
+                                              filter_percentile=filter_percentile,
+                                              display_width=display_width))
+            html_output.append('</li>')
+
+        html_output.append('</ul>')  # list end
+
+ # Display the HTML text & return it as well
+ display(HTML("".join(html_output)))
+ return html_output
+
+ def sentence_attribution_map(self,
+ sentence: str,
+ token_extractor: WordExtractor,
+ explainer_class: Type[BlackBoxExplainer],
+ most_important_concepts: np.ndarray,
+ concepts_names: List[str] = None,
+ ignore_words: List[str] = None,
+ filter_percentile: int = 80,
+ display_width: int = 400) -> str:
+ """
+ Compute the concepts attribution maps for the sentence given in argument,
+ and return the corresponding sentence as an HTML formatted sentence where
+ the words belonging to each concept are highlighted with colors.
+
+ Parameters
+ ----------
+ sentence
+ The sentence to process.
+ token_extractor
+ The token extractor used to extract tokens from the sentences.
+ explainer_class
+ Explainer class used during the masking process. Typically a NlpOcclusion class.
+ most_important_concepts
+ The concepts ids to display.
+ concepts_names
+ The concepts names to use for the HTML display, corresponding to
+ most_important_concepts.
+ ignore_words
+ Words to ignore during the occlusion process. These will not be part of the
+ extracted concepts.
+ filter_percentile
+ Percentile used to filter the concept heatmap when displaying the HTML
+ sentences (only show concept if excess N-th percentile). Default is 80.
+ display_width
+ Width of the displayed HTML text. Default is 400.
+
+ Returns
+ -------
+ html_lines
+ The sentence formatted as HTML text.
+ """
+ extracted_words, separator = token_extractor.extract_tokens(sentence)
+
+ # Filter words
+ if ignore_words is not None:
+ words = [word for word in extracted_words if word not in ignore_words]
+ else:
+ words = extracted_words
+
+ explainer = explainer_class(model=self.transform)
+ l_importances = explainer.explain(
+ sentence=sentence, words=words, separator=separator)
+
+ # Display only the most important concepts
+ l_importances = l_importances[most_important_concepts]
+
+ return self.convert_sentence_to_html(extracted_words, separator, words, l_importances,
+ concepts_names=concepts_names,
+ filter_percentile=filter_percentile,
+ display_width=display_width)
+
+ @staticmethod
+ def get_legend(concepts_names: List[str]) -> str:
+ """
+ Generates an HTML legend for the concepts attribution maps.
+
+ Parameters
+ ----------
+ concepts_names
+ The list of concepts names that are handled by the current Cockatiel instance.
+
+ Returns
+ -------
+ html_legend
+ The legend formatted as HTML text.
+ """
+
+ # Extract the label name from the concept name, which are composed of [css_namespace]_[cid]
+ labels = [
+ f"Concept {c_name.split('_')[-1]}" for c_name in concepts_names]
+ legend_importances = np.eye(len(concepts_names)) / 2.0
+ return CockatielNlpVisualizationMixin.\
+ convert_sentence_to_html(extracted_words=labels,
+ words=labels,
+ separator=" ",
+ concepts_names=concepts_names,
+ explanation=legend_importances)
+
+ @staticmethod
+ def convert_sentence_to_html(
+ extracted_words: List[str],
+ separator: str,
+ words: List[str],
+ explanation: np.ndarray,
+ concepts_names: List[str],
+ filter_percentile: int = 80,
+ display_width: int = 400
+ ) -> str:
+ """
+ Generates the visualization for COCKATIEL's explanations.
+
+ Parameters
+ ----------
+ extracted_words
+ List of extracted words from the input sentence: it shall
+ be possible to reconstruct the whole input sentence using `extracted_words`
+ and `separator`.
+ separator
+ Separator used to join extracted_words to build the whole sentence.
+ words
+ List of words used to generate the explanation, these words should be part
+ of the sentence and are attached to a concept of the `explanation` parameter below.
+ explanation
+ An array that corresponds to the output of the occlusion function.
+ Has a shape of (nb_concepts x nb_words).
+ concepts_names
+ The list of concepts names that are handled by the current Cockatiel instance.
+ filter_percentile
+ Percentile used to filter the concept heatmap when displaying the HTML
+ sentences (only show concept if excess N-th percentile). Default is 80.
+ display_width
+ Width of the displayed HTML text. Default is 400.
+
+ Returns
+ -------
+ html_lines
+ The sentence formatted as HTML text.
+ """
+ l_phi = np.array(explanation)
+
+ # Filter the values that are below the percentile for each concept importance array
+ sigmas = [[np.percentile(phi, filter_percentile)] for phi in l_phi]
+ l_phi = l_phi * np.array(l_phi > sigmas, np.float32)
+
+ phi_html = []
+
+ # Build a dictionary of structure 'word' => importance_value, concept_id
+ words_importances = dict(
+ zip(words, zip(l_phi.max(axis=0), l_phi.argmax(axis=0))))
+
+ i = 0
+ for word in extracted_words:
+ if word in words_importances:
+ value, value_max_id = words_importances[word]
+
+ # concept name is composed of css_namespace and c_id
+ concept_name = concepts_names[value_max_id]
+ c_id = concept_name.split('_')[-1]
+
+ title = f"[{html.escape(word)}] concept:{c_id} importance:{value:.2f}"
+
+ # add debug infos in the popup
+ # indices = np.nonzero(l_phi[:, i])[0]
+ # title += "\n DEBUG:"
+ # for j in indices:
+ # concept_name = concepts_names[j]
+ # title += f"\n concept:{concept_name} importance {l_phi[j,i]:.2f}"
+
+                if value > 0:
+                    # NOTE(review): span tags reconstructed — original HTML was stripped
+                    phi_html.append(
+                        f"<span class='concept{concept_name}' title='{title}'>"
+                        f"{word}</span>{separator}")
+                else:
+                    phi_html.append(
+                        f"<span title='{title}'>{word}</span>{separator}")
+
+ i += 1
+ else:
+ # ignored words
+ phi_html.append(f" {word}{separator}")
+
+        html_result = f"<p style='width: {display_width}px'>" \
+            + " ".join(phi_html) + "</p>"
+        return html_result
+
+ @staticmethod
+ def build_concepts_css(concepts_names: List[str], cmaps: List[str]):
+ """
+ Generates the CSS for HTML COCKATIEL's explanations.
+
+ Parameters
+ ----------
+ concepts_names
+ The list of concepts names.
+ cmaps
+ The color associated with each concept.
+ Can be either:
+ - A tuple (r, g, b) of color to use as a base for the colormap.
+ - A colormap name compatible with `plt.get_cmap(cmap)`.
+ """
+        # NOTE(review): the CSS content of this literal was lost (markup stripped
+        # from the source) — TODO restore the legend stylesheet from upstream.
+        legend_css = """
+
+        """
+
+ concepts_css = []
+ for c_id, cmap in zip(concepts_names, cmaps):
+            # NOTE(review): alpha is scaled to 0-255 here but is used below as a
+            # CSS opacity value, which expects the 0-1 range — confirm intent.
+            red, green, blue, alpha = [int(c*255) for c in cmap(1)]
+ txt = f"\n.concept{c_id} {{" \
+ f" padding: 1px 5px;" \
+ f" border: solid 3px;" \
+ f" border-color: rgba({red}, {green}, {blue});" \
+ f" border-radius: 10px;" \
+ f" background-color: rgba({red}, {green}, {blue}, var(--opacity));" \
+ f" --opacity: {alpha};" \
+ f"}}"
+
+ concepts_css.append(txt)
+        css_text = "<style>" + "".join(concepts_css) + "</style>"
+
+ return legend_css + css_text
+
+
+class CockatielTorch(BaseCockatiel, CockatielNlpVisualizationMixin):
+ """
+ Class implementing Cockatiel on Pytorch, adapted for text visualization.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must be a Pytorch model accepting tokenized inputs.
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ Must be a Pytorch model.
+ preprocessor
+ A callable object to transform strings into inputs for the model.
+ Embeds a tokenizer that will be used to feed the inputs to the model.
+ number_of_concepts
+ The number of concepts to extract. Default is 25.
+ batch_size
+ The batch size for all the operations that use the model. Default is 256.
+ device
+ The type of device on which to place the torch tensors
+ """
diff --git a/xplique/concepts/cockatiel_manager.py b/xplique/concepts/cockatiel_manager.py
new file mode 100644
index 00000000..d2f8d5ac
--- /dev/null
+++ b/xplique/concepts/cockatiel_manager.py
@@ -0,0 +1,256 @@
+"""
+COCKATIEL MANAGER Module for Pytorch
+"""
+
+from typing import Callable, List, Type
+import numpy as np
+import torch
+
+from xplique.attributions.base import BlackBoxExplainer
+
+from .craft_torch import BaseCraftManagerTorch
+from .cockatiel import CockatielTorch
+from ..commons import TokenExtractor, NlpPreprocessor
+from ..commons.torch_operations import nlp_batch_predict
+
+
+class BaseCockatielManagerTorch(BaseCraftManagerTorch):
+ """
+ Base class implementing the CockatielManager on Pytorch.
+ This manager creates one Cockatiel instance per class to explain.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must return positive activations.
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ preprocessor
+ A callable object to transform strings into inputs for the model.
+ inputs
+ Input data: a list of strings.
+ labels
+ Labels of the inputs.
+ list_of_class_of_interest
+        A list of the classes id to explain. The manager will instantiate one
+ CraftTorch object per element of this list.
+ number_of_concepts
+ The number of concepts to extract. Default is 20.
+ batch_size
+ The batch size to use during training and prediction. Default is 64.
+ device
+ The type of device on which to place the torch tensors
+ """
+
+ def __init__(self, input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ preprocessor: NlpPreprocessor,
+ inputs: List[str],
+ labels: List[str],
+ list_of_class_of_interest: list = None,
+ number_of_concepts: int = 20,
+ batch_size: int = 64,
+ device: str = 'cuda'):
+
+ super().__init__(input_to_latent_model=input_to_latent_model,
+ latent_to_logit_model=latent_to_logit_model,
+ inputs=inputs,
+ labels=labels,
+ list_of_class_of_interest=list_of_class_of_interest)
+ self.preprocessor = preprocessor
+ self.batch_size = batch_size
+ self.device = device
+ for class_of_interest in self.list_of_class_of_interest:
+ craft = CockatielTorch(input_to_latent_model=input_to_latent_model,
+ latent_to_logit_model=latent_to_logit_model,
+ preprocessor=preprocessor,
+ number_of_concepts=number_of_concepts,
+ batch_size=batch_size,
+ device=device)
+ self.craft_instances[class_of_interest] = craft
+
+ def compute_predictions(self) -> np.ndarray:
+ """
+ Computes predictions using the input-to-latent and latent-to-logit models,
+ on the input data provided at the creation of the CockatielManager instance.
+
+ Returns
+ -------
+ np.ndarray
+ Predicted labels.
+ """
+ features, _ = nlp_batch_predict(
+ self.input_to_latent_model, preprocessor=self.preprocessor,
+ inputs=self.inputs, labels=self.labels, batch_size=self.batch_size)
+ logits = self.latent_to_logit_model(features)
+ y_preds = np.array(torch.argmax(logits, -1).cpu().detach())
+ return y_preds
+
+ def get_best_excerpts_per_concept(self,
+ class_id: int,
+ nb_excerpts: int = 10,
+ nb_most_important_concepts: int = None) -> List[str]:
+ """
+ Return the best excerpts for each concept.
+
+ Parameters
+ ----------
+ class_id
+ The class to explain.
+ nb_excerpts
+ The number of excerpts (patches) to fetch per concept. Defaults to 10.
+ nb_most_important_concepts
+ The number of concepts to display. If provided, only display
+ nb_most_important_concepts, otherwise display them all.
+ Default is None.
+
+ Returns
+ -------
+ best_excerpts_per_concept
+ The list of the best excerpts per concept
+ """
+ return self.craft_instances[class_id]\
+ .get_best_excerpts_per_concept(nb_excerpts, nb_most_important_concepts)
+
+
+class CockatielManagerNlpVisualizationMixin:
+ """
+ Class containing text visualization methods for CockatielManager.
+ """
+
+ def plot_concepts_importances(self,
+ class_id: int,
+ nb_most_important_concepts: int = 5,
+ verbose: bool = False):
+ """
+ Plot a bar chart displaying the importance value of each concept.
+
+ Parameters
+ ----------
+ class_id
+ The class to explain.
+ nb_most_important_concepts
+ The number of concepts to focus on. Default is 5.
+ verbose
+ If True, then print the importance value of each concept, otherwise no textual
+ output will be printed.
+ """
+ self.craft_instances[class_id]\
+ .plot_concepts_importances(importances=None,
+ nb_most_important_concepts=nb_most_important_concepts,
+ verbose=verbose)
+
+ def display_concepts_excerpts(self,
+ class_id: int,
+ nb_patchs: int = 10,
+ nb_most_important_concepts: int = None) -> None:
+ """
+ Display the best excerpts for each concept.
+
+ Parameters
+ ----------
+ class_id
+ The class to explain.
+ nb_patchs
+ The number of patches to display per concept. Defaults to 10.
+ nb_most_important_concepts
+ The number of concepts to display. If provided, only display
+ nb_most_important_concepts, otherwise display them all.
+ Default is None.
+ """
+ self.craft_instances[class_id]\
+ .display_concepts_excerpts(nb_patchs=nb_patchs,
+ nb_most_important_concepts=nb_most_important_concepts)
+
+ def plot_concept_attribution_maps(self,
+ class_id: int,
+ sentences: List[str],
+ token_extractor: TokenExtractor,
+ explainer_class: Type[BlackBoxExplainer],
+ ignore_words: List[str] = None,
+ importances: np.ndarray = None,
+ nb_most_important_concepts: int = 5,
+ title: str = "",
+ filter_percentile: int = 80,
+ display_width: int = 400) -> List[str]:
+ """
+ Display the concepts attribution maps for the sentences given in argument.
+
+ Parameters
+ ----------
+ class_id
+ The class id to explain.
+ sentences
+ The list of sentences for which attribution maps will be displayed.
+ token_extractor
+ The token extractor used to extract tokens from the sentences.
+ explainer_class
+ Explainer class used during the masking process. Typically a NlpOcclusion class.
+ ignore_words
+ Words to ignore during the occlusion process. These will not be part of the
+ extracted concepts.
+ importances
+ The importances computed by the estimate_importance() method.
+ If None is provided, then the global importances will be used, otherwise
+ the local importances set in this parameter will be used.
+ nb_most_important_concepts
+ Number of most important concepts to consider. Default is 5.
+ title
+ The title to use when displaying the results.
+ filter_percentile
+ Percentile used to filter the concept heatmap when displaying the HTML
+ sentences (only show concept if excess N-th percentile). Default is 80.
+ display_width
+ Width of the displayed HTML text. Default is 400.
+
+ Returns
+ -------
+ html_lines
+ List of HTML text lines.
+ """
+ self.craft_instances[class_id].plot_concept_attribution_maps(
+ sentences=sentences,
+ token_extractor=token_extractor,
+ explainer_class=explainer_class,
+ ignore_words=ignore_words,
+ importances=importances,
+ nb_most_important_concepts=nb_most_important_concepts,
+ title=title,
+ filter_percentile=filter_percentile,
+ display_width=display_width,
+ css_namespace=class_id)
+
+
+class CockatielManagerTorch(BaseCockatielManagerTorch, CockatielManagerNlpVisualizationMixin):
+ """
+ Class implementing CockatielManager on Pytorch, adapted for text visualization.
+ This manager creates one CockatielTorch instance per class to explain.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must return positive activations.
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ preprocessor
+ A callable object to transform strings into inputs for the model.
+ inputs
+ Input data: a list of strings.
+ labels
+ Labels of the inputs of shape (n_samples, class_id)
+ list_of_class_of_interest
+        A list of the classes id to explain. The manager will instantiate one
+ CraftTorch object per element of this list.
+ number_of_concepts
+ The number of concepts to extract. Default is 20.
+ batch_size
+ The batch size to use during training and prediction. Default is 64.
+ device
+ The type of device on which to place the torch tensors
+ """
diff --git a/xplique/concepts/craft.py b/xplique/concepts/craft.py
index dbfb5809..415ce152 100644
--- a/xplique/concepts/craft.py
+++ b/xplique/concepts/craft.py
@@ -7,6 +7,7 @@
from enum import Enum
import colorsys
from math import ceil
+import random
import numpy as np
import cv2
@@ -16,12 +17,14 @@
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import gridspec
-from xplique.attributions.global_sensitivity_analysis import (HaltonSequenceRS, JansenEstimator)
+from xplique.attributions.global_sensitivity_analysis import \
+ (HaltonSequenceRS, ScipySobolSequenceRS, LatinHypercubeRS, JansenEstimator)
from xplique.plots.image import _clip_percentile
from ..types import Callable, Tuple, Optional, Union
from .base import BaseConceptExtractor
+
@dataclasses.dataclass
class Factorization:
""" Dataclass handling data produced during the Factorization step."""
@@ -32,6 +35,7 @@ class Factorization:
crops_u: np.ndarray
concept_bank_w: np.ndarray
+
class Sensitivity:
"""
Dataclass handling data produced during the Sobol indices computation.
@@ -53,14 +57,14 @@ class Sensitivity:
"""
def __init__(self, importances: np.ndarray,
- most_important_concepts: np.ndarray,
- cmaps: Optional[Union[list, str]]=None):
+ most_important_concepts: np.ndarray,
+ cmaps: Optional[Union[list, str]] = None):
self.importances = importances
self.most_important_concepts = most_important_concepts
self.set_concept_attribution_cmap(cmaps=cmaps)
@staticmethod
- def _get_alpha_cmap(cmap: Union[Tuple,str]):
+ def _get_alpha_cmap(cmap: Union[Tuple, str]):
"""
Creat a colormap with an alpha channel, out of 3 (r, g, b) values.
This is used in particular by `set_concept_attribution_cmap()` to
@@ -93,12 +97,12 @@ def _get_alpha_cmap(cmap: Union[Tuple,str]):
cmap = LinearSegmentedColormap.from_list("", [colors, cmax])
alpha_cmap = cmap(np.arange(256))
- alpha_cmap[:,-1] = np.linspace(0, 0.85, 256)
+ alpha_cmap[:, -1] = np.linspace(0, 0.85, 256)
alpha_cmap = ListedColormap(alpha_cmap)
return alpha_cmap
- def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]]=None):
+ def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]] = None):
"""
Set the colormap used for the concepts displayed in the attribution maps.
@@ -112,24 +116,33 @@ def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]]=None):
"""
if cmaps is None:
self.cmaps = [
- Sensitivity._get_alpha_cmap((54, 197, 240)),
- Sensitivity._get_alpha_cmap((210, 40, 95)),
- Sensitivity._get_alpha_cmap((236, 178, 46)),
- Sensitivity._get_alpha_cmap((15, 157, 88)),
- Sensitivity._get_alpha_cmap((84, 25, 85)),
- Sensitivity._get_alpha_cmap((55, 35, 235))
- ]
+ Sensitivity._get_alpha_cmap((54, 197, 240)),
+ Sensitivity._get_alpha_cmap((210, 40, 95)),
+ Sensitivity._get_alpha_cmap((236, 178, 46)),
+ Sensitivity._get_alpha_cmap((15, 157, 88)),
+ Sensitivity._get_alpha_cmap((84, 25, 85)),
+ Sensitivity._get_alpha_cmap((55, 35, 235))
+ ]
# Add more colors by default
cmaps_more = [Sensitivity._get_alpha_cmap(cmap)
- for cmap in plt.get_cmap('tab10').colors]
+ for cmap in plt.get_cmap('tab10').colors]
self.cmaps.extend(cmaps_more)
else:
self.cmaps = [Sensitivity._get_alpha_cmap(cmap) for cmap in cmaps]
if len(self.cmaps) < len(self.most_important_concepts):
- raise RuntimeError(f'Not enough colors in cmaps ({len(self.cmaps)}) ' \
- f'compared to the number of important concepts ' \
- '({len(self.most_important_concepts)})')
+ nb_colors_missing = len(
+ self.most_important_concepts) - len(self.cmaps)
+ print(f'Not enough colors in cmaps ({len(self.cmaps)}) '
+ f'compared to the number of important concepts '
+ f'({len(self.most_important_concepts)}). '
+ f'Adding {nb_colors_missing} random colors.')
+
+ for _ in range(nb_colors_missing):
+ random_color = (random.random(),
+ random.random(), random.random())
+ self.cmaps.append(Sensitivity._get_alpha_cmap(random_color))
+
class DisplayImportancesOrder(Enum):
"""
@@ -139,11 +152,21 @@ class DisplayImportancesOrder(Enum):
LOCAL: Order concepts by their Local importance on a single sample
"""
GLOBAL = 0
- LOCAL = 1
+ LOCAL = 1
def __eq__(self, other):
return self.value == other.value
+
+class MaskSampler(Enum):
+ HALTON = HaltonSequenceRS
+ SOBOL = ScipySobolSequenceRS
+ LATIN = LatinHypercubeRS
+
+ def __eq__(self, other):
+ return self.value == other.value
+
+
class BaseCraft(BaseConceptExtractor, ABC):
"""
Base class implementing the CRAFT Concept Extraction Mechanism.
@@ -171,11 +194,11 @@ class BaseCraft(BaseConceptExtractor, ABC):
The size of the patches (crops) to extract from the input data. Default is 64.
"""
- def __init__(self, input_to_latent_model : Callable,
- latent_to_logit_model : Callable,
- number_of_concepts: int = 20,
- batch_size: int = 64,
- patch_size: int = 64):
+ def __init__(self, input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ number_of_concepts: int = 20,
+ batch_size: int = 64,
+ patch_size: int = 64):
super().__init__(number_of_concepts, batch_size)
self.input_to_latent_model = input_to_latent_model
self.latent_to_logit_model = latent_to_logit_model
@@ -185,11 +208,10 @@ def __init__(self, input_to_latent_model : Callable,
self.sensitivity = None
# sanity checks
- assert(hasattr(input_to_latent_model, "__call__")), \
- "input_to_latent_model must be a callable function"
- assert(hasattr(latent_to_logit_model, "__call__")), \
- "latent_to_logit_model must be a callable function"
-
+ assert (hasattr(input_to_latent_model, "__call__")), \
+ "input_to_latent_model must be a callable function"
+ assert (hasattr(latent_to_logit_model, "__call__")), \
+ "latent_to_logit_model must be a callable function"
@abstractmethod
def _latent_predict(self, inputs: np.ndarray):
@@ -216,11 +238,13 @@ def check_if_fitted(self):
If the factorization model has not been fitted to input data.
"""
if self.factorization is None:
- raise NotFittedError("The factorization model has not been fitted to input data yet.")
+ raise NotFittedError(
+ "The factorization model has not been fitted to input data yet.")
def fit(self,
- inputs : np.ndarray,
- class_id: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ inputs: np.ndarray,
+ class_id: int = 0,
+ alpha_w: float = 1e-2) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Fit the Craft model to the input data.
@@ -231,6 +255,8 @@ def fit(self,
(x1, x2, ..., xn) in the paper.
class_id
The class id of the inputs.
+ alpha_w
+ Constant that multiplies the NMF regularization terms of the concept bank (W in the paper).
Returns
-------
@@ -246,7 +272,7 @@ def fit(self,
crops, activations = self._extract_patches(inputs)
# apply NMF to the activations to obtain matrices U and W
- reducer = NMF(n_components=self.number_of_concepts, alpha_W=1e-2)
+ reducer = NMF(n_components=self.number_of_concepts, alpha_W=alpha_w)
crops_u = reducer.fit_transform(activations)
concept_bank_w = reducer.components_.astype(np.float32)
@@ -255,7 +281,7 @@ def fit(self,
return crops, crops_u, concept_bank_w
- def transform(self, inputs : np.ndarray, activations : np.ndarray = None) -> np.ndarray:
+ def transform(self, inputs: np.ndarray, activations: np.ndarray = None) -> np.ndarray:
"""Transforms the inputs data into its concept representation.
Parameters
@@ -289,10 +315,15 @@ def transform(self, inputs : np.ndarray, activations : np.ndarray = None) -> np.
if is_4d:
# (N * W * H, R) -> (N, W, H, R) with R = nb_concepts
- coeffs_u = np.reshape(coeffs_u, (*original_shape, coeffs_u.shape[-1]))
+ coeffs_u = np.reshape(
+ coeffs_u, (*original_shape, coeffs_u.shape[-1]))
return coeffs_u
- def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -> np.ndarray:
+ def estimate_importance(self,
+ inputs: np.ndarray = None,
+ sampler: MaskSampler = MaskSampler.HALTON,
+ nb_design: int = 32,
+ cmaps: Optional[Union[Tuple, str]] = None) -> np.ndarray:
"""
Estimates the importance of each concept for a given class, either globally
on the whole dataset provided in the fit() method (in this case, inputs shall
@@ -305,8 +336,15 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
If None, then the inputs provided in the fit() method
will be used (global importance of the whole dataset).
Default is None.
+ sampler
+ The sampling method to use for masking. Defaults to MaskSampler.HALTON.
nb_design
The number of design to use for the importance estimation. Default is 32.
+ cmaps
+ The list of colors associated with each concept.
+ Can be either:
+ - A list of (r, g, b) colors to use as a base for the colormap.
+ - A colormap name compatible with `plt.get_cmap(cmap)`.
Returns
-------
@@ -323,7 +361,7 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
coeffs_u = self.transform(inputs)
- masks = HaltonSequenceRS()(self.number_of_concepts, nb_design = nb_design)
+ masks = sampler.value()(self.number_of_concepts, nb_design=nb_design)
estimator = JansenEstimator()
importances = []
@@ -348,9 +386,9 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
for coeff in coeffs_u:
u_perturbated = coeff[None, :] * masks[:, None, None, :]
a_perturbated = np.reshape(u_perturbated,
- (-1, coeff.shape[-1])) @ self.factorization.concept_bank_w
+ (-1, coeff.shape[-1])) @ self.factorization.concept_bank_w
a_perturbated = np.reshape(a_perturbated,
- (len(masks), coeffs_u.shape[1], coeffs_u.shape[2], -1))
+ (len(masks), coeffs_u.shape[1], coeffs_u.shape[2], -1))
# a_perturbated: (N, H, W, C)
y_pred = self._logit_predict(a_perturbated)
@@ -365,14 +403,15 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -
# Save the results of the computation if working on the whole dataset
if compute_global_importances:
most_important_concepts = np.argsort(importances)[::-1]
- self.sensitivity = Sensitivity(importances, most_important_concepts)
+ self.sensitivity = Sensitivity(
+ importances, most_important_concepts, cmaps=cmaps)
return importances
def plot_concepts_importances(self,
importances: np.ndarray = None,
- display_importance_order: DisplayImportancesOrder = \
- DisplayImportancesOrder.GLOBAL,
+ display_importance_order: DisplayImportancesOrder =
+ DisplayImportancesOrder.GLOBAL,
nb_most_important_concepts: int = None,
verbose: bool = False):
"""
@@ -422,7 +461,8 @@ def plot_concepts_importances(self,
most_important_concepts = most_important_concepts[:nb_most_important_concepts]
# Find the correct color index
- global_color_index_order = np.argsort(self.sensitivity.importances)[::-1]
+ global_color_index_order = np.argsort(
+ self.sensitivity.importances)[::-1]
local_color_index_order = [np.where(global_color_index_order == local_c)[0][0]
for local_c in most_important_concepts]
colors = np.array([colors(1.0)
@@ -441,7 +481,8 @@ def plot_concepts_importances(self,
if verbose:
for c_id in most_important_concepts:
- print(f"Concept {c_id} has an importance value of {importances[c_id]:.2f}")
+ print(
+ f"Concept {c_id} has an importance value of {importances[c_id]:.2f}")
@staticmethod
def _show(img, **kwargs):
@@ -455,6 +496,46 @@ def _show(img, **kwargs):
plt.imshow(img, **kwargs)
plt.axis('off')
+ def _gen_best_concepts_crops(self,
+ nb_crops: int = 10,
+ nb_most_important_concepts: int = None) \
+ -> Tuple[int, float, np.ndarray]:
+ """
+ Generate the best concept crops for each concept.
+
+ Parameters
+ ----------
+ nb_crops : int
+ The number of crops (patches) to display per concept. Defaults to 10.
+ nb_most_important_concepts : int
+ The number of concepts to consider. If provided, only take into account
+ nb_most_important_concepts, otherwise use them all.
+ Default is None.
+ Returns
+ -------
+ Tuple
+ A tuple containing:
+ - The current concept id.
+ - The overall importance score for this concept.
+ - An array containing the best crops for this concept.
+ """
+ most_important_concepts = self.sensitivity.most_important_concepts
+ if nb_most_important_concepts is not None:
+ most_important_concepts = most_important_concepts[:nb_most_important_concepts]
+
+ for c_id in most_important_concepts:
+ best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[
+ ::-1][:nb_crops]
+ best_crops = np.array(self.factorization.crops)[best_crops_ids]
+ c_id_importance = self.sensitivity.importances[c_id]
+ yield c_id, c_id_importance, best_crops
+
+
+class CraftImageVisualizationMixin():
+ """
+ Class containing image visualization methods for Craft.
+ """
+
def plot_concepts_crops(self,
nb_crops: int = 10,
nb_most_important_concepts: int = None,
@@ -474,17 +555,11 @@ def plot_concepts_crops(self,
If True, then print the importance value of each concept,
otherwise no textual output will be printed.
"""
- most_important_concepts = self.sensitivity.most_important_concepts
- if nb_most_important_concepts is not None:
- most_important_concepts = most_important_concepts[:nb_most_important_concepts]
-
- for c_id in most_important_concepts:
- best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[::-1][:nb_crops]
- best_crops = np.array(self.factorization.crops)[best_crops_ids]
-
+ for c_id, c_id_importance, best_crops in \
+ self._gen_best_concepts_crops(nb_crops, nb_most_important_concepts):
if verbose:
- print(f"Concept {c_id} has an importance value of " \
- f"{self.sensitivity.importances[c_id]:.2f}")
+ print(f"Concept {c_id} has an importance value of "
+ f"{c_id_importance:.2f}")
plt.figure(figsize=(7, (2.5/2)*ceil(nb_crops/5)))
for i in range(nb_crops):
plt.subplot(ceil(nb_crops/5), 5, i+1)
@@ -493,7 +568,7 @@ def plot_concepts_crops(self,
def plot_concept_attribution_legend(self,
nb_most_important_concepts: int = 6,
- border_width: int=5):
+ border_width: int = 5):
"""
Plot a legend for the concepts attribution maps.
@@ -511,13 +586,14 @@ def plot_concept_attribution_legend(self,
cmap = self.sensitivity.cmaps[i]
plt.subplot(1, len(most_important_concepts), i+1)
- best_crops_id = np.argsort(self.factorization.crops_u[:, c_id])[::-1][0]
+ best_crops_id = np.argsort(
+ self.factorization.crops_u[:, c_id])[::-1][0]
best_crop = self.factorization.crops[best_crops_id]
if best_crop.shape[0] > best_crop.shape[-1]:
- mask = np.zeros(best_crop.shape[:-1]) # tf
+ mask = np.zeros(best_crop.shape[:-1]) # tf
else:
- mask = np.zeros(best_crop.shape[1:]) # torch
+ mask = np.zeros(best_crop.shape[1:]) # torch
mask[:border_width, :] = 1.0
mask[:, :border_width] = 1.0
mask[-border_width:, :] = 1.0
@@ -569,22 +645,23 @@ def compute_subplots_layout_parameters(images: np.ndarray,
# get width and height of our images
if images.shape[1] == 3:
nb_samples = images.shape[0]
- [l_width, l_height] = images.shape[2:4] # pytorch
+ [l_width, l_height] = images.shape[2:4] # pytorch
else:
- [nb_samples, l_width, l_height] = images.shape[0:3] # tf
+ [nb_samples, l_width, l_height] = images.shape[0:3] # tf
rows = ceil(nb_samples / cols)
# define the figure margin, width, height in inch
figwidth = cols * img_size + (cols-1) * spacing + 2 * margin
- figheight = rows * img_size * l_height/l_width + (rows-1) * spacing + 2 * margin
+ figheight = rows * img_size * l_height / \
+ l_width + (rows-1) * spacing + 2 * margin
layout_parameters = {
- 'left' : margin/figwidth,
- 'right' : 1.-(margin/figwidth),
- 'bottom' : margin/figheight,
- 'top' : 1.-(margin/figheight),
- 'wspace' : spacing/img_size,
- 'hspace' : spacing/img_size * l_width/l_height
+ 'left': margin/figwidth,
+ 'right': 1.-(margin/figwidth),
+ 'bottom': margin/figheight,
+ 'top': 1.-(margin/figheight),
+ 'wspace': spacing/img_size,
+ 'hspace': spacing/img_size * l_width/l_height
}
return layout_parameters, rows, figwidth, figheight
@@ -630,7 +707,8 @@ def plot_concept_attribution_maps(self,
Additional parameters passed to `plt.imshow()`.
"""
- self.plot_concept_attribution_legend(nb_most_important_concepts=nb_most_important_concepts)
+ self.plot_concept_attribution_legend(
+ nb_most_important_concepts=nb_most_important_concepts)
# Take into account single vs multiple images array
if len(images.shape) == 3:
@@ -638,7 +716,8 @@ def plot_concept_attribution_maps(self,
# Configure the subplots
layout_configuration, rows, figwidth, figheight = \
- self.compute_subplots_layout_parameters(images=images, cols=cols, img_size=img_size)
+ self.compute_subplots_layout_parameters(
+ images=images, cols=cols, img_size=img_size)
fig = plt.figure()
fig.set_size_inches(figwidth, figheight)
@@ -666,13 +745,13 @@ def plot_concept_attribution_maps(self,
**plot_kwargs)
def plot_concept_attribution_map(self,
- image: np.ndarray,
- most_important_concepts: np.ndarray,
- nb_most_important_concepts: int = 5,
- filter_percentile: int = 90,
- clip_percentile: Optional[float] = 10,
- alpha: float = 0.65,
- **plot_kwargs):
+ image: np.ndarray,
+ most_important_concepts: np.ndarray,
+ nb_most_important_concepts: int = 5,
+ filter_percentile: int = 90,
+ clip_percentile: Optional[float] = 10,
+ alpha: float = 0.65,
+ **plot_kwargs):
"""
Display the concepts attribution map for a single image given in argument.
@@ -702,15 +781,16 @@ def plot_concept_attribution_map(self,
most_important_concepts = most_important_concepts[:nb_most_important_concepts]
# Find the colors corresponding to the importances
- global_color_index_order = np.argsort(self.sensitivity.importances)[::-1]
+ global_color_index_order = np.argsort(
+ self.sensitivity.importances)[::-1]
local_color_index_order = [np.where(global_color_index_order == local_c)[0][0]
for local_c in most_important_concepts]
local_cmap = np.array(self.sensitivity.cmaps)[local_color_index_order]
if image.shape[0] == 3:
- dsize = image.shape[1:3] # pytorch
+ dsize = image.shape[1:3] # pytorch
else:
- dsize = image.shape[0:2] # tf
+ dsize = image.shape[0:2] # tf
BaseCraft._show(image, **plot_kwargs)
image_u = self.transform(image)[0]
@@ -718,7 +798,8 @@ def plot_concept_attribution_map(self,
heatmap = image_u[:, :, c_id]
# only show concept if excess N-th percentile
- sigma = np.percentile(np.array(heatmap).flatten(), filter_percentile)
+ sigma = np.percentile(
+ np.array(heatmap).flatten(), filter_percentile)
heatmap = heatmap * np.array(heatmap > sigma, np.float32)
# resize the heatmap before cliping
@@ -727,12 +808,13 @@ def plot_concept_attribution_map(self,
if clip_percentile:
heatmap = _clip_percentile(heatmap, clip_percentile)
- BaseCraft._show(heatmap, cmap=local_cmap[::-1][i], alpha=alpha, **plot_kwargs)
+ BaseCraft._show(
+ heatmap, cmap=local_cmap[::-1][i], alpha=alpha, **plot_kwargs)
def plot_image_concepts(self,
img: np.ndarray,
- display_importance_order: DisplayImportancesOrder = \
- DisplayImportancesOrder.GLOBAL,
+ display_importance_order: DisplayImportancesOrder =
+ DisplayImportancesOrder.GLOBAL,
nb_most_important_concepts: int = 5,
filter_percentile: int = 90,
clip_percentile: Optional[float] = 10,
@@ -776,7 +858,8 @@ def plot_image_concepts(self,
if display_importance_order == DisplayImportancesOrder.LOCAL:
# compute the importances for the sample input in argument
importances = self.estimate_importance(inputs=img)
- most_important_concepts = np.argsort(importances)[::-1][:nb_most_important_concepts]
+ most_important_concepts = np.argsort(
+ importances)[::-1][:nb_most_important_concepts]
else:
# use the global importances computed on the whole dataset
importances = self.sensitivity.importances
@@ -787,7 +870,8 @@ def plot_image_concepts(self,
# the crops, and the central part to display the heatmap
nb_rows = ceil(len(most_important_concepts) / 2.0)
nb_cols = 4
- gs_main = fig.add_gridspec(nb_rows, nb_cols, hspace=0.4, width_ratios=[0.2, 0.4, 0.2, 0.4])
+ gs_main = fig.add_gridspec(
+ nb_rows, nb_cols, hspace=0.4, width_ratios=[0.2, 0.4, 0.2, 0.4])
# Central image
#
@@ -802,8 +886,8 @@ def plot_image_concepts(self,
# Concepts: creation of the axes on left and right of the image for the concepts
#
- gs_concepts_axes = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 0])
- for i in range(nb_rows)]
+ gs_concepts_axes = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 0])
+ for i in range(nb_rows)]
gs_right = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 2])
for i in range(nb_rows)]
gs_concepts_axes.extend(gs_right)
@@ -812,7 +896,8 @@ def plot_image_concepts(self,
nb_crops = 6
# compute the right color to use for the crops
- global_color_index_order = np.argsort(self.sensitivity.importances)[::-1]
+ global_color_index_order = np.argsort(
+ self.sensitivity.importances)[::-1]
local_color_index_order = [np.where(global_color_index_order == local_c)[0][0]
for local_c in most_important_concepts]
local_cmap = np.array(self.sensitivity.cmaps)[local_color_index_order]
@@ -821,22 +906,24 @@ def plot_image_concepts(self,
cmap = local_cmap[i]
# use a ghost invisible subplot only to have a border around the crops
- ghost_axe = fig.add_subplot(gs_concepts_axes[i][:,:])
+ ghost_axe = fig.add_subplot(gs_concepts_axes[i][:, :])
ghost_axe.set_title(f"{c_id}", color=cmap(1.0))
ghost_axe.axis('off')
- inset_axes = ghost_axe.inset_axes([-0.04, -0.04, 1.08, 1.08]) # outer border
+ inset_axes = ghost_axe.inset_axes(
+ [-0.04, -0.04, 1.08, 1.08]) # outer border
inset_axes.set_xticks([])
inset_axes.set_yticks([])
- for spine in inset_axes.spines.values(): # border color
+ for spine in inset_axes.spines.values(): # border color
spine.set_edgecolor(color=cmap(1.0))
spine.set_linewidth(3)
# draw each crop for this concept
- gs_current = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=
- gs_concepts_axes[i][:,:])
+ gs_current = gridspec.GridSpecFromSubplotSpec(
+ 2, 3, subplot_spec=gs_concepts_axes[i][:, :])
- best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[::-1][:nb_crops]
+ best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[
+ ::-1][:nb_crops]
best_crops = np.array(self.factorization.crops)[best_crops_ids]
for i in range(nb_crops):
axe = plt.Subplot(fig, gs_current[i // 3, i % 3])
@@ -844,9 +931,10 @@ def plot_image_concepts(self,
BaseCraft._show(best_crops[i])
# Right plot: importances
- importance_axe = gridspec.GridSpecFromSubplotSpec(3, 2, width_ratios=[0.1, 0.9],
- height_ratios=[0.15, 0.6, 0.15],
- subplot_spec=gs_main[:, 3])
+ importance_axe = gridspec.GridSpecFromSubplotSpec(3, 2, width_ratios=[0.1, 0.9],
+ height_ratios=[
+ 0.15, 0.6, 0.15],
+ subplot_spec=gs_main[:, 3])
fig.add_subplot(importance_axe[1, 1])
self.plot_concepts_importances(importances=importances,
display_importance_order=display_importance_order,
diff --git a/xplique/concepts/craft_manager.py b/xplique/concepts/craft_manager.py
index fb94bf17..7b286bd5 100644
--- a/xplique/concepts/craft_manager.py
+++ b/xplique/concepts/craft_manager.py
@@ -6,9 +6,10 @@
import numpy as np
-from ..types import Callable, Optional
+from ..types import Callable, Optional, Union, Tuple
from .craft import DisplayImportancesOrder
+
class BaseCraftManager(ABC):
"""
Base class implementing the CRAFT Concept Extraction Mechanism on multiple classes.
@@ -31,11 +32,11 @@ class BaseCraftManager(ABC):
@abstractmethod
def __init__(self,
- input_to_latent_model : Callable,
- latent_to_logit_model : Callable,
- inputs : np.ndarray,
- labels : np.ndarray,
- list_of_class_of_interest : Optional[list] = None):
+ input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ inputs: np.ndarray,
+ labels: np.ndarray,
+ list_of_class_of_interest: Optional[list] = None):
self.input_to_latent_model = input_to_latent_model
self.latent_to_logit_model = latent_to_logit_model
self.inputs = inputs
@@ -45,7 +46,7 @@ def __init__(self,
# take all the classes
list_of_class_of_interest = np.array(list(set(labels)))
self.list_of_class_of_interest = list_of_class_of_interest
- self.craft_instances = None
+ self.craft_instances = {}
@abstractmethod
def compute_predictions(self):
@@ -61,7 +62,10 @@ def compute_predictions(self):
"""
raise NotImplementedError
- def fit(self, nb_samples_per_class: Optional[int] = None, verbose: bool = False):
+ def fit(self,
+ nb_samples_per_class: Optional[int] = None,
+ alpha_w: float = 1e-2,
+ verbose: bool = False):
"""
Fit the Craft models on their respective class of interest.
@@ -70,6 +74,8 @@ def fit(self, nb_samples_per_class: Optional[int] = None, verbose: bool = False)
nb_samples_per_class
Number of samples to use to fit the Craft model.
Default is None, which means that all the samples will be used.
+ alpha_w
+ Constant that multiplies the NMF regularization terms of `W`.
verbose
If True, then print the current class CRAFT is fitting,
otherwise no textual output will be printed.
@@ -85,7 +91,8 @@ def fit(self, nb_samples_per_class: Optional[int] = None, verbose: bool = False)
if nb_samples_per_class is not None:
class_inputs = class_inputs[:nb_samples_per_class]
class_labels = class_labels[:nb_samples_per_class]
- craft_instance.fit(class_inputs, class_id=class_of_interest)
+ craft_instance.fit(
+ class_inputs, class_id=class_of_interest, alpha_w=alpha_w)
def estimate_importance(self, nb_design: int = 32, verbose: bool = False):
"""
@@ -104,6 +111,32 @@ def estimate_importance(self, nb_design: int = 32, verbose: bool = False):
print(f'Estimating importances for class {class_of_interest} ')
craft_instance.estimate_importance(nb_design=nb_design)
+ def set_concept_attribution_cmap(self,
+ class_id: int,
+ cmaps: Optional[Union[Tuple, str]] = None):
+ """
+ Set the colormap used for the concepts displayed in the attribution maps,
+ from the class 'class_id'.
+
+ Parameters
+ ----------
+ class_id
+ The class to explain.
+ cmaps
+ The list of colors associated with each concept.
+ Can be either:
+ - A list of (r, g, b) colors to use as a base for the colormap.
+ - A colormap name compatible with `plt.get_cmap(cmap)`.
+ """
+ self.craft_instances[class_id].sensitivity.set_concept_attribution_cmap(
+ cmaps)
+
+
+class CraftManagerImageVisualizationMixin():
+ """
+ Class containing image visualization methods for CraftManager.
+ """
+
def plot_concepts_importances(self,
class_id: int,
nb_most_important_concepts: int = 5,
@@ -121,9 +154,10 @@ def plot_concepts_importances(self,
If True, then print the importance value of each concept, otherwise no textual
output will be printed.
"""
- self.craft_instances[class_id].plot_concepts_importances(importances = None,
- nb_most_important_concepts=nb_most_important_concepts,
- verbose=verbose)
+ self.craft_instances[class_id]\
+ .plot_concepts_importances(importances=None,
+ nb_most_important_concepts=nb_most_important_concepts,
+ verbose=verbose)
def plot_concepts_crops(self,
class_id: int, nb_crops: int = 10,
@@ -143,13 +177,13 @@ def plot_concepts_crops(self,
Default is None.
"""
self.craft_instances[class_id].plot_concepts_crops(nb_crops=nb_crops,
- nb_most_important_concepts=nb_most_important_concepts)
+ nb_most_important_concepts=nb_most_important_concepts)
def plot_image_concepts(self,
img: np.ndarray,
class_id: int,
- display_importance_order: DisplayImportancesOrder = \
- DisplayImportancesOrder.GLOBAL,
+ display_importance_order: DisplayImportancesOrder =
+ DisplayImportancesOrder.GLOBAL,
nb_most_important_concepts: int = 5,
filter_percentile: int = 90,
clip_percentile: Optional[float] = 10,
@@ -188,9 +222,9 @@ def plot_image_concepts(self,
Path the file will be saved at. If None, the function will call plt.show().
"""
self.craft_instances[class_id].plot_image_concepts(img,
- display_importance_order=display_importance_order,
- nb_most_important_concepts=nb_most_important_concepts,
- filter_percentile=filter_percentile,
- clip_percentile=clip_percentile,
- alpha=alpha,
- filepath=filepath)
+ display_importance_order=display_importance_order,
+ nb_most_important_concepts=nb_most_important_concepts,
+ filter_percentile=filter_percentile,
+ clip_percentile=clip_percentile,
+ alpha=alpha,
+ filepath=filepath)
diff --git a/xplique/concepts/craft_tf.py b/xplique/concepts/craft_tf.py
index 1569cdc9..7a95732b 100644
--- a/xplique/concepts/craft_tf.py
+++ b/xplique/concepts/craft_tf.py
@@ -7,13 +7,13 @@
import tensorflow as tf
import numpy as np
-from .craft import BaseCraft
-from .craft_manager import BaseCraftManager
+from .craft import BaseCraft, CraftImageVisualizationMixin
+from .craft_manager import BaseCraftManager, CraftManagerImageVisualizationMixin
-class CraftTf(BaseCraft):
+class BaseCraftTf(BaseCraft):
"""
- Class implementing the CRAFT Concept Extraction Mechanism on Tensorflow.
+ Base class implementing the CRAFT Concept Extraction Mechanism on Tensorflow.
Parameters
----------
@@ -33,11 +33,12 @@ class CraftTf(BaseCraft):
patch_size
The size of the patches to extract from the input data. Default is 64.
"""
- def __init__(self, input_to_latent_model : Callable,
- latent_to_logit_model : Callable,
- number_of_concepts: int = 20,
- batch_size: int = 64,
- patch_size: int = 64):
+
+ def __init__(self, input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ number_of_concepts: int = 20,
+ batch_size: int = 64,
+ patch_size: int = 64):
super().__init__(input_to_latent_model,
latent_to_logit_model,
number_of_concepts,
@@ -48,9 +49,9 @@ def __init__(self, input_to_latent_model : Callable,
keras_base_layer = tf.keras.Model
is_tf_model = issubclass(type(input_to_latent_model), keras_base_layer) & \
- issubclass(type(latent_to_logit_model), keras_base_layer)
+ issubclass(type(latent_to_logit_model), keras_base_layer)
if not is_tf_model:
- raise TypeError('input_to_latent_model and latent_to_logit_model are not '\
+ raise TypeError('input_to_latent_model and latent_to_logit_model are not '
'Tensorflow models')
def _latent_predict(self, inputs: tf.Tensor):
@@ -110,16 +111,19 @@ def _extract_patches(self, inputs: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
strides = int(self.patch_size * 0.80)
patches = tf.image.extract_patches(images=inputs,
- sizes=[1, self.patch_size, self.patch_size, 1],
+ sizes=[1, self.patch_size,
+ self.patch_size, 1],
strides=[1, strides, strides, 1],
rates=[1, 1, 1, 1],
padding='VALID')
- patches = tf.reshape(patches, (-1, self.patch_size, self.patch_size, inputs.shape[-1]))
+ patches = tf.reshape(
+ patches, (-1, self.patch_size, self.patch_size, inputs.shape[-1]))
# encode the patches and obtain the activations
input_width, input_height = inputs.shape[1], inputs.shape[2]
activations = self._latent_predict(tf.image.resize(patches,
- (input_width, input_height),
+ (input_width,
+ input_height),
method="bicubic"))
assert np.min(activations) >= 0.0, "Activations must be positive."
@@ -138,9 +142,34 @@ def _to_np_array(self, inputs: tf.Tensor, dtype: type):
return np.array(inputs, dtype)
-class CraftManagerTf(BaseCraftManager):
+class CraftTf(BaseCraftTf, CraftImageVisualizationMixin):
"""
- Class implementing the CraftManager on Tensorflow.
+ Class implementing the CRAFT Concept Extraction Mechanism on Tensorflow,
+ adapted for image processing.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must be a Tensorflow model (tf.keras.engine.base_layer.Layer) accepting
+ data of shape (n_samples, height, width, channels).
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ Must be a Tensorflow model (tf.keras.engine.base_layer.Layer).
+ number_of_concepts
+ The number of concepts to extract. Default is 20.
+ batch_size
+ The batch size to use during training and prediction. Default is 64.
+ patch_size
+ The size of the patches to extract from the input data. Default is 64.
+ """
+
+
+class BaseCraftManagerTf(BaseCraftManager):
+ """
+ Base class implementing the CraftManager on Tensorflow.
This manager creates one CraftTf instance per class to explain.
Parameters
@@ -167,19 +196,19 @@ class CraftManagerTf(BaseCraftManager):
patch_size
The size of the patches (crops) to extract from the input data. Default is 64.
"""
- def __init__(self, input_to_latent_model : Callable,
- latent_to_logit_model : Callable,
- inputs : np.ndarray,
- labels : np.ndarray,
- list_of_class_of_interest : Optional[list] = None,
- number_of_concepts: int = 20,
- batch_size: int = 64,
- patch_size: int = 64):
+
+ def __init__(self, input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ inputs: np.ndarray,
+ labels: np.ndarray,
+ list_of_class_of_interest: Optional[list] = None,
+ number_of_concepts: int = 20,
+ batch_size: int = 64,
+ patch_size: int = 64):
super().__init__(input_to_latent_model, latent_to_logit_model,
inputs, labels, list_of_class_of_interest)
- self.craft_instances = {}
for class_of_interest in self.list_of_class_of_interest:
craft = CraftTf(input_to_latent_model, latent_to_logit_model,
number_of_concepts, batch_size, patch_size)
@@ -196,5 +225,36 @@ def compute_predictions(self):
the predictions
"""
y_preds = np.array(tf.argmax(self.latent_to_logit_model.predict(
- self.input_to_latent_model.predict(self.inputs)), 1))
+ self.input_to_latent_model.predict(self.inputs)), 1))
return y_preds
+
+
+class CraftManagerTf(BaseCraftManagerTf, CraftManagerImageVisualizationMixin):
+ """
+ Class implementing the CraftManager on Tensorflow, adapted for image processing.
+ This manager creates one CraftTf instance per class to explain.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must return positive activations.
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ inputs
+ Input data of shape (n_samples, height, width, channels).
+ (x1, x2, ..., xn) in the paper.
+ labels
+ Labels of the inputs of shape (n_samples, class_id)
+ list_of_class_of_interest
+ A list of the class ids to explain. The manager will instantiate one
+ CraftTf object per element of this list.
+ number_of_concepts
+ The number of concepts to extract. Default is 20.
+ batch_size
+ The batch size to use during training and prediction. Default is 64.
+ patch_size
+ The size of the patches (crops) to extract from the input data. Default is 64.
+ """
diff --git a/xplique/concepts/craft_torch.py b/xplique/concepts/craft_torch.py
index f429cce4..2b659df5 100644
--- a/xplique/concepts/craft_torch.py
+++ b/xplique/concepts/craft_torch.py
@@ -3,19 +3,21 @@
"""
from typing import Callable, Optional, Tuple
+from types import MethodType
from math import ceil
import torch
from torch import nn
import numpy as np
-from .craft import BaseCraft
-from .craft_manager import BaseCraftManager
+from .craft import BaseCraft, CraftImageVisualizationMixin
+from .craft_manager import BaseCraftManager, CraftManagerImageVisualizationMixin
+
def _batch_inference(model: torch.nn.Module,
dataset: torch.Tensor,
batch_size: int = 128,
resize: Optional[int] = None,
- device: str='cuda') -> torch.Tensor:
+ device: str = 'cuda') -> torch.Tensor:
"""
Compute the model predictions of the input images.
@@ -59,9 +61,9 @@ def _batch_inference(model: torch.nn.Module,
return results
-class CraftTorch(BaseCraft):
+class BaseCraftTorch(BaseCraft):
"""
- Class Implementing the CRAFT Concept Extraction Mechanism on Pytorch.
+ Base class implementing the CRAFT Concept Extraction Mechanism on Pytorch.
Parameters
----------
@@ -85,22 +87,25 @@ class CraftTorch(BaseCraft):
"""
def __init__(self, input_to_latent_model: Callable,
- latent_to_logit_model: Callable,
- number_of_concepts: int = 20,
- batch_size: int = 64,
- patch_size: int = 64,
- device : str = 'cuda'):
+ latent_to_logit_model: Callable,
+ number_of_concepts: int = 20,
+ batch_size: int = 64,
+ patch_size: int = 64,
+ device: str = 'cuda'):
super().__init__(input_to_latent_model, latent_to_logit_model,
number_of_concepts, batch_size)
self.patch_size = patch_size
self.device = device
# Check model type
+ is_method = isinstance(input_to_latent_model, MethodType) & \
+ isinstance(latent_to_logit_model, MethodType)
is_torch_model = issubclass(type(input_to_latent_model), torch.nn.modules.module.Module) & \
- issubclass(type(latent_to_logit_model), torch.nn.modules.module.Module)
- if not is_torch_model:
- raise TypeError('input_to_latent_model and latent_to_logit_model are not ' \
- 'Pytorch modules')
+ issubclass(type(latent_to_logit_model),
+ torch.nn.modules.module.Module)
+ if not (is_method or is_torch_model):
+ raise TypeError('input_to_latent_model and latent_to_logit_model are not '
+ 'Pytorch modules nor methods')
def _latent_predict(self, inputs: torch.Tensor, resize=None) -> torch.Tensor:
"""
@@ -118,7 +123,8 @@ def _latent_predict(self, inputs: torch.Tensor, resize=None) -> torch.Tensor:
"""
# inputs: (N, C, H, W)
if len(inputs.shape) == 3:
- inputs = inputs.unsqueeze(0) # add an extra dim in case we get only 1 image to predict
+ # add an extra dim in case we get only 1 image to predict
+ inputs = inputs.unsqueeze(0)
activations = _batch_inference(self.input_to_latent_model, inputs,
self.batch_size, resize, device=self.device)
@@ -147,7 +153,8 @@ def _logit_predict(self, activations: np.ndarray, resize=None) -> torch.Tensor:
if len(activations_perturbated.shape) == 4:
# activations_perturbated: (N, H, W, C) -> (N, C, H, W)
- activations_perturbated = activations_perturbated.permute(0, 3, 1, 2)
+ activations_perturbated = activations_perturbated.permute(
+ 0, 3, 1, 2)
y_pred = _batch_inference(self.latent_to_logit_model, activations_perturbated,
self.batch_size, resize, device=self.device)
@@ -176,7 +183,8 @@ def _extract_patches(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, np.ndarr
# extract patches from the input data, keep patches on cpu
strides = int(self.patch_size * 0.80)
- patches = torch.nn.functional.unfold(inputs, kernel_size=self.patch_size, stride=strides)
+ patches = torch.nn.functional.unfold(
+ inputs, kernel_size=self.patch_size, stride=strides)
patches = patches.transpose(1, 2).contiguous().view(-1, num_channels,
self.patch_size, self.patch_size)
@@ -205,9 +213,36 @@ def _to_np_array(self, inputs: torch.Tensor, dtype: type = None):
return res
-class CraftManagerTorch(BaseCraftManager):
+class CraftTorch(BaseCraftTorch, CraftImageVisualizationMixin):
+ """
+ Class Implementing the CRAFT Concept Extraction Mechanism on Pytorch,
+ adapted for image processing.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must be a Pytorch model (torch.nn.modules.module.Module) accepting
+ data of shape (n_samples, channels, height, width).
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ Must be a Pytorch model (torch.nn.modules.module.Module).
+ number_of_concepts
+ The number of concepts to extract. Default is 20.
+ batch_size
+ The batch size to use during training and prediction. Default is 64.
+ patch_size
+ The size of the patches (crops) to extract from the input data. Default is 64.
+ device
+ The device to use. Default is 'cuda'.
+ """
+
+
+class BaseCraftManagerTorch(BaseCraftManager):
"""
- Class implementing the CraftManager on Tensorflow.
+ Base class implementing the CraftManager on Pytorch.
This manager creates one CraftTorch instance per class to explain.
Parameters
@@ -234,22 +269,22 @@ class CraftManagerTorch(BaseCraftManager):
patch_size
The size of the patches (crops) to extract from the input data. Default is 64.
"""
- def __init__(self, input_to_latent_model : Callable,
- latent_to_logit_model : Callable,
- inputs : np.ndarray,
- labels : np.ndarray,
- list_of_class_of_interest : list = None,
- number_of_concepts: int = 20,
- batch_size: int = 64,
- patch_size: int = 64,
- device : str = 'cuda'):
+
+ def __init__(self, input_to_latent_model: Callable,
+ latent_to_logit_model: Callable,
+ inputs: np.ndarray,
+ labels: np.ndarray,
+ list_of_class_of_interest: list = None,
+ number_of_concepts: int = 20,
+ batch_size: int = 64,
+ patch_size: int = 64,
+ device: str = 'cuda'):
super().__init__(input_to_latent_model, latent_to_logit_model,
inputs, labels, list_of_class_of_interest)
self.batch_size = batch_size
self.device = device
- self.craft_instances = {}
for class_of_interest in self.list_of_class_of_interest:
craft = CraftTorch(input_to_latent_model, latent_to_logit_model,
number_of_concepts, batch_size, patch_size, device)
@@ -265,8 +300,41 @@ def compute_predictions(self):
y_preds
the predictions
"""
- model = nn.Sequential(self.input_to_latent_model, self.latent_to_logit_model)
+ model = nn.Sequential(self.input_to_latent_model,
+ self.latent_to_logit_model)
activations = _batch_inference(model, self.inputs, self.batch_size, None,
device=self.device)
- y_preds = np.array(torch.argmax(activations, -1)) # pylint disable=no-member
+ y_preds = np.array(torch.argmax(activations, -1)
+ ) # pylint: disable=no-member
return y_preds
+
+
+class CraftManagerTorch(BaseCraftManagerTorch, CraftManagerImageVisualizationMixin):
+ """
+ Class implementing the CraftManager on Pytorch, adapted for image processing.
+ This manager creates one CraftTorch instance per class to explain.
+
+ Parameters
+ ----------
+ input_to_latent_model
+ The first part of the model taking an input and returning
+ positive activations, g(.) in the original paper.
+ Must return positive activations.
+ latent_to_logit_model
+ The second part of the model taking activation and returning
+ logits, h(.) in the original paper.
+ inputs
+ Input data of shape (n_samples, height, width, channels).
+ (x1, x2, ..., xn) in the paper.
+ labels
+ Labels of the inputs of shape (n_samples, class_id)
+ list_of_class_of_interest
+ A list of the class ids to explain. The manager will instantiate one
+ CraftTorch object per element of this list.
+ number_of_concepts
+ The number of concepts to extract. Default is 20.
+ batch_size
+ The batch size to use during training and prediction. Default is 64.
+ patch_size
+ The size of the patches (crops) to extract from the input data. Default is 64.
+ """
diff --git a/xplique/features_visualizations/preconditioning.py b/xplique/features_visualizations/preconditioning.py
index f5758f8d..fc98fcc6 100644
--- a/xplique/features_visualizations/preconditioning.py
+++ b/xplique/features_visualizations/preconditioning.py
@@ -16,13 +16,6 @@
IMAGENET_SPECTRUM_URL = "https://storage.googleapis.com/serrelab/loupe/"\
"spectrums/imagenet_decorrelated.npy"
-imagenet_color_correlation = tf.cast(
- [[0.56282854, 0.58447580, 0.58447580],
- [0.19482528, 0.00000000,-0.19482528],
- [0.04329450,-0.10823626, 0.06494176]], tf.float32
-)
-
-
def recorrelate_colors(images: tf.Tensor) -> tf.Tensor:
"""
Map uncorrelated colors to 'normal colors' by using empirical color
@@ -39,6 +32,11 @@ def recorrelate_colors(images: tf.Tensor) -> tf.Tensor:
images
Images recorrelated.
"""
+ imagenet_color_correlation = tf.cast(
+ [[0.56282854, 0.58447580, 0.58447580],
+ [0.19482528, 0.00000000,-0.19482528],
+ [0.04329450,-0.10823626, 0.06494176]], tf.float32
+ )
images_flat = tf.reshape(images, [-1, 3])
images_flat = tf.matmul(images_flat, imagenet_color_correlation)
return tf.reshape(images_flat, tf.shape(images))
diff --git a/xplique/utils_functions/object_detection.py b/xplique/utils_functions/object_detection.py
index c8bbbe6b..cc1c1016 100644
--- a/xplique/utils_functions/object_detection.py
+++ b/xplique/utils_functions/object_detection.py
@@ -4,9 +4,6 @@
from typing import Tuple
import tensorflow as tf
-_EPSILON = tf.constant(1e-4)
-
-
def _box_iou(boxes_a: tf.Tensor, boxes_b: tf.Tensor) -> tf.Tensor:
"""
Compute the intersection between two batched bounding boxes.
@@ -40,6 +37,7 @@ def _box_iou(boxes_a: tf.Tensor, boxes_b: tf.Tensor) -> tf.Tensor:
union_area = a_area + b_area - intersection_area
+ _EPSILON = tf.constant(1e-4)
iou_score = intersection_area / (union_area + _EPSILON)
return iou_score