"""
Tests for the NLP occlusion attribution method (NlpOcclusion).
"""
import numpy as np

from xplique.attributions import NlpOcclusion


def test_masks():
    """Test the creation of the occlusion masks."""
    sentence = "aaa bbb ccc"
    words = sentence.split(" ")
    # _get_masks expects the number of tokens, not the token list itself
    masks = NlpOcclusion._get_masks(len(words))
    assert masks.shape == (len(words), len(words))
    # one mask per word: every word is kept (True) except the occluded one
    expected_masks = np.array([[False, True, True],
                               [True, False, True],
                               [True, True, False]])

    assert np.array_equal(masks, expected_masks)


def test_apply_masks():
    """Test that applying the masks removes exactly one word per sentence."""
    sentence = "aaa bbb ccc"
    words = sentence.split(" ")
    masks = NlpOcclusion._get_masks(len(words))

    occluded_inputs = NlpOcclusion._apply_masks(words, masks)
    expected_occluded_inputs = [['bbb', 'ccc'], ['aaa', 'ccc'], ['aaa', 'bbb']]
    assert np.array_equal(occluded_inputs, expected_occluded_inputs)


def test_output_shape():
    """Test the output shape for several input sentences."""

    nb_concepts = 10

    def transform(inputs):
        # simulate the transform method used in Craft/Cockatiel:
        # one row per input sentence, one column per concept
        return np.ones((len(inputs), nb_concepts))

    input_sentences = ["aaa bbb ccc ddd eee fff", "ggg hhh iii jjj"]
    for sentence in input_sentences:
        words = sentence.split(" ")
        separator = " "

        method = NlpOcclusion(model=transform)
        sensitivity = method.explain(sentence, words, separator)

        assert sensitivity.shape == (nb_concepts, len(words))
"""
Tests for the Cockatiel concept explainer (PyTorch / transformers backend).
"""
import numpy as np
import torch
import torch.nn as nn

from torch.nn import MSELoss
from transformers import RobertaPreTrainedModel, RobertaModel
from transformers import RobertaTokenizerFast
from xplique.concepts import CockatielTorch as Cockatiel
from xplique.commons.torch_operations import NlpPreprocessor


class FakeImdbClassifier(torch.nn.Module):
    """
    Deterministic stand-in for a text classifier: `features` returns a
    predictable arange-based latent, `classifier` returns alternating
    one-hot-like rows, so Cockatiel's outputs are reproducible.
    """

    def __init__(self, nb_classes):
        super().__init__()
        self.nb_classes = nb_classes

    def features(self, **kwargs):
        # deterministic latent: consecutive integers reshaped to
        # (batch, nb_classes)
        start_value = 0
        return torch.arange(start_value,
                            start_value + kwargs['input_ids'].shape[0] * self.nb_classes).\
            reshape(kwargs['input_ids'].shape[0], self.nb_classes)

    def classifier(self, latent):
        # Simulates a classifier, returns alternatively [0, 1] and [1, 0]
        # as result
        n = len(latent)
        pattern = torch.cat([torch.tensor([0, 1]), torch.tensor([1, 0])])
        ypred = torch.cat([pattern] * n).reshape(n * 2, -1)
        return ypred

    def forward(self, **kwargs):
        return self.classifier(self.features(**kwargs))


class ImdbPreprocessor(NlpPreprocessor):
    """
    Concrete NlpPreprocessor for IMDB reviews: tokenizes the texts and maps
    'positive'/'negative' string labels to 1/0 int tensors.
    """

    def preprocess(self, inputs: np.ndarray, labels: np.ndarray):
        preprocessed_inputs = self.tokenize(samples=inputs.tolist())
        preprocessed_labels = torch.Tensor(np.array(
            labels.tolist()) == 'positive').int().to(self.device)
        return preprocessed_inputs, preprocessed_labels


def test_shape():
    """Ensure the output shape is correct"""

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    pretrained_model_path = "wrmurray/roberta-base-finetuned-imdb"
    tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_path)
    nb_classes = 2
    model = FakeImdbClassifier(nb_classes=nb_classes)
    model = model.eval()

    imdb_preprocessor = ImdbPreprocessor(tokenizer,
                                         device,
                                         padding="max_length",
                                         max_length=512,
                                         truncation=True,
                                         return_tensors='pt')

    input_to_latent_model = model.features
    latent_to_logit_model = model.classifier
    number_of_concepts = 20
    cockatiel_explainer_pos = Cockatiel(input_to_latent_model=input_to_latent_model,
                                        latent_to_logit_model=latent_to_logit_model,
                                        preprocessor=imdb_preprocessor,
                                        number_of_concepts=number_of_concepts,
                                        batch_size=64,
                                        device=device)

    data = ["Absolutely riveting from beginning to end. Test 1 2 3. Another test.",
            "But it's a great movie.",
            "This is the best movie of the year to date.",
            "The movie is excellent."]
    # NOTE(review): the 4 inputs yield 6 crops (excerpts) — presumably the
    # first input is split into 3 excerpts; confirm against ExcerptExtractor.
    crops, crops_u, concept_bank_w \
        = cockatiel_explainer_pos.fit(inputs=data,
                                      class_id=1, alpha_w=0)
    assert len(crops) == 6
    assert crops_u.shape == (6, number_of_concepts)
    assert concept_bank_w.shape == (number_of_concepts, nb_classes)

    global_importance_pos = cockatiel_explainer_pos.estimate_importance()
    assert len(global_importance_pos) == number_of_concepts


class CustomRobertaForSequenceClassification(RobertaPreTrainedModel):
    """
    A custom RoBERTa model using a custom fully-connected head with a non-negative layer on which
    we can compute the NMF.

    Parameters
    ----------
    config
        An object indicating the hidden layer size, the presence and amount of dropout for the
        classification head.
    """
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 1
        self.config = config

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, 2)

        self.roberta = RobertaModel(config, add_pooling_layer=False)

        self.mse_loss = MSELoss()

        self.post_init()

    def classifier_features(self, x):
        # first half of the head; the final ReLU keeps activations
        # non-negative so an NMF can be computed on them
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.relu(x)
        return x

    def classifier_end_model(self, x):
        # second half of the head: dropout + projection to the 2 logits
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

    def classifier(self, x):
        x = self.classifier_features(x)
        x = self.classifier_end_model(x)
        return x

    def features(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # chain RobertaModel and the classifier features ending with a ReLU
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # [CLS] token representation
        sequence_output = outputs[0][:, 0, :]
        return self.classifier_features(sequence_output)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        activations = self.features(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.classifier_end_model(activations)
        return logits


def test_classifier():
    """End-to-end Cockatiel run on a real fine-tuned RoBERTa head."""

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    pretrained_model_path = "wrmurray/roberta-base-finetuned-imdb"
    tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_path)
    model = CustomRobertaForSequenceClassification.from_pretrained(
        pretrained_model_path).to(device)
    model = model.eval()

    imdb_preprocessor = ImdbPreprocessor(tokenizer,
                                         device,
                                         padding="max_length",
                                         max_length=512,
                                         truncation=True,
                                         return_tensors='pt')

    input_to_latent_model = model.features
    latent_to_logit_model = model.classifier
    number_of_concepts = 2
    cockatiel_explainer_pos = Cockatiel(input_to_latent_model=input_to_latent_model,
                                        latent_to_logit_model=latent_to_logit_model,
                                        preprocessor=imdb_preprocessor,
                                        number_of_concepts=number_of_concepts,
                                        batch_size=64,
                                        device=device)

    data = ["Absolutely great riveting from beginning to end. It was great. Great great great.",
            "But it's a great movie.",
            "I liked this film.",
            "I enjoyed the scenario of this movie."]  # first sentences contain 'great', the others do not
    crops, crops_u, concept_bank_w \
        = cockatiel_explainer_pos.fit(inputs=data,
                                      class_id=1, alpha_w=0)
    assert len(crops) == 6
    assert crops_u.shape == (6, number_of_concepts)
    assert concept_bank_w.shape[0] == number_of_concepts

    global_importance_pos = cockatiel_explainer_pos.estimate_importance()
    assert len(global_importance_pos) == number_of_concepts
    most_important_concepts_ids = global_importance_pos.argsort()

    nb_excerpts = 2
    nb_most_important_concepts = 2

    best_sentences_per_concept = cockatiel_explainer_pos.get_best_excerpts_per_concept(
        nb_excerpts=nb_excerpts, nb_most_important_concepts=nb_most_important_concepts)
    # the least important concept gathers the 'great'-free sentences,
    # the most important one the 'great' excerpts
    assert np.all(best_sentences_per_concept[most_important_concepts_ids[0]] == [
        'I enjoyed the scenario of this movie.', 'I liked this film.'])
    assert np.all(best_sentences_per_concept[most_important_concepts_ids[1]] == [
        'Great great great.', 'It was great.'])
"""
Tests for the NLP torch utilities: NlpPreprocessor, batcher and
nlp_batch_predict.
"""
import numpy as np

import torch
import pytest
from transformers import RobertaTokenizerFast
from xplique.commons.torch_operations import NlpPreprocessor, batcher, nlp_batch_predict


class ImdbPreprocessor(NlpPreprocessor):
    """
    Concrete NlpPreprocessor for IMDB reviews: tokenizes the texts and maps
    'positive'/'negative' string labels to 1/0 int tensors.
    """

    def preprocess(self, inputs: np.ndarray, labels: np.ndarray):
        preprocessed_inputs = self.tokenize(samples=inputs.tolist())
        preprocessed_labels = torch.Tensor(np.array(
            labels.tolist()) == 'positive').int().to(self.device)
        return preprocessed_inputs, preprocessed_labels


@pytest.fixture
def imdb_preprocessor():
    """Build a CPU ImdbPreprocessor backed by the RoBERTa IMDB tokenizer."""
    pretrained_model_path = "wrmurray/roberta-base-finetuned-imdb"
    tokenizer = RobertaTokenizerFast.from_pretrained(pretrained_model_path)
    return ImdbPreprocessor(tokenizer,
                            device='cpu',
                            padding="max_length",
                            max_length=512,
                            truncation=True,
                            return_tensors='pt')


def test_imdb_preprocessor(imdb_preprocessor):
    """Check tokenizer output keys/shapes and the label encoding."""
    inputs = np.array(['This is an example sentence.',
                       'Another sentence for testing.'])
    labels = np.array(['positive', 'negative'])
    x_preprocessed, y_preprocessed = imdb_preprocessor.preprocess(
        inputs, labels)
    assert 'input_ids' in x_preprocessed
    assert 'attention_mask' in x_preprocessed
    # padding="max_length" + max_length=512 give fixed-size (2, 512) tensors
    assert x_preprocessed['input_ids'].shape == x_preprocessed['attention_mask'].shape == (
        2, 512)
    assert y_preprocessed.shape == (2,)
    # 'positive' -> 1, 'negative' -> 0
    assert y_preprocessed[0] == 1
    assert y_preprocessed[1] == 0


def test_batcher():
    """Check batcher on divisible, non-divisible, oversized and exact sizes."""
    # Test case 1: elements perfectly divisible by batch_size
    elements1 = [1, 2, 3, 4, 5, 6]
    batch_size1 = 2
    batches1 = list(batcher(elements1, batch_size1))
    assert batches1 == [[1, 2], [3, 4], [5, 6]]

    # Test case 2: elements not perfectly divisible by batch_size
    elements2 = [1, 2, 3, 4, 5]
    batch_size2 = 2
    batches2 = list(batcher(elements2, batch_size2))
    assert batches2 == [[1, 2], [3, 4], [5]]

    # Test case 3: batch_size greater than the length of elements
    elements3 = [1, 2, 3, 4, 5]
    batch_size3 = 10
    batches3 = list(batcher(elements3, batch_size3))
    assert batches3 == [[1, 2, 3, 4, 5]]

    # Test case 4: batch_size equal to the length of elements
    elements4 = [1, 2, 3, 4, 5]
    batch_size4 = 5
    batches4 = list(batcher(elements4, batch_size4))
    assert batches4 == [[1, 2, 3, 4, 5]]

    # Test case 5: with a complex list of elements
    elements5 = list(zip([1, 2, 3, 4, 5], ['a', 'b', 'c', 'd', 'e']))
    batch_size5 = 2
    batches5 = list(batcher(elements5, batch_size5))
    assert batches5 == [[(1, 'a'), (2, 'b')], [(3, 'c'), (4, 'd')], [(5, 'e')]]


class ImdbClassifier(torch.nn.Module):
    """Fake classifier returning random logits with the expected shape."""

    def __init__(self, nb_classes):
        super().__init__()
        self.nb_classes = nb_classes

    def forward(self, **kwargs):
        # only the output shape matters for these tests
        return torch.randn(kwargs['input_ids'].shape[0], self.nb_classes)


def test_batch_predict(imdb_preprocessor):
    """Check nlp_batch_predict returns one prediction and label per input."""
    inputs = ['text1', 'text2', 'text3']
    labels = ['positive', 'negative', 'positive']

    batch_size = 2
    nb_classes = 2
    mock_model = ImdbClassifier(nb_classes)

    predictions, processed_labels = nlp_batch_predict(
        mock_model, imdb_preprocessor, inputs, labels, batch_size)

    assert len(predictions) == len(processed_labels) == len(inputs)
"""
Tests for the token extractors in xplique.commons.nlp.
"""
from xplique.commons.nlp import WordExtractor, SentenceExtractor
from xplique.commons.nlp import FlairClauseExtractor, ExcerptExtractor, ExtractorFactory, SpacyClauseExtractor
from flair.models import SequenceTagger

import spacy

import pytest


@pytest.fixture
def example_sentence():
    """Deliberately messy text: missing spaces after dots, repeated
    punctuation, and sentences not starting with a capital letter."""
    return "One two three. Second sentence.Third Sentence, test1, test2; test3-test4 .GO!"\
        " Trust me,. sentence not starting with capital letter."\
        " Sentence with dots..Word, Word, word,...word ....so a sentence"


@pytest.fixture
def chunk_english_tagger():
    """Flair chunking tagger (downloads the model on first use)."""
    return SequenceTagger.load("flair/chunk-english")


@pytest.fixture
def web_sm_pipeline():
    # NOTE(review): downloads the spacy model at every test session —
    # consider guarding with a cache / availability check.
    spacy.cli.download("en_core_web_sm")
    return spacy.load("en_core_web_sm")


def test_word_extractor(example_sentence):
    """NLTK word tokenization, separator is a single space."""
    extractor = WordExtractor()
    tokens, separator = extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    assert separator == ' '
    expected_tokens = ['One', 'two', 'three', '.',
                       'Second', 'sentence.Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
                       'test3-test4', '.GO', '!', 'Trust', 'me', ',', '.', 'sentence', 'not',
                       'starting', 'with', 'capital', 'letter', '.', 'Sentence', 'with', 'dots',
                       '..', 'Word', ',', 'Word', ',', 'word', ',', '...', 'word', '....', 'so',
                       'a', 'sentence']
    assert tokens == expected_tokens, print('tokens:', tokens)


def test_word_extractor_ignore_words(example_sentence):
    # NOTE(review): this test is currently byte-identical to
    # test_word_extractor — presumably an "ignore words" option was intended;
    # confirm the API and make the expectation differ.
    extractor = WordExtractor()
    tokens, separator = extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    assert separator == ' '
    expected_tokens = ['One', 'two', 'three', '.',
                       'Second', 'sentence.Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
                       'test3-test4', '.GO', '!', 'Trust', 'me', ',', '.', 'sentence', 'not',
                       'starting', 'with', 'capital', 'letter', '.', 'Sentence', 'with', 'dots',
                       '..', 'Word', ',', 'Word', ',', 'word', ',', '...', 'word', '....',
                       'so', 'a', 'sentence']
    assert tokens == expected_tokens, print('tokens:', tokens)


def test_word_extractor_from_list(example_sentence):
    """A list of sentences is accepted as input as well."""
    extractor = WordExtractor()
    tokens, separator = extractor.extract_tokens(
        [example_sentence, example_sentence])
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    assert separator == ' '


def test_sentence_extractor(example_sentence):
    """NLTK sentence tokenization, separator is '. '."""
    extractor = SentenceExtractor()
    tokens, separator = extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    assert separator == '. '
    expected_tokens = ['One two three.',
                       'Second sentence.Third Sentence, test1, test2; test3-test4 .GO!',
                       'Trust me,.',
                       'sentence not starting with capital letter.',
                       'Sentence with dots..Word, Word, word,...word ....so a sentence']
    assert tokens == expected_tokens, print('tokens:', tokens)


def test_excerpt_extractor(example_sentence):
    """Excerpts: split on .?! and drop chunks not starting with a capital."""
    extractor = ExcerptExtractor()
    tokens, separator = extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    assert separator == ' '
    expected_tokens = ['One two three.',
                       'Second sentence.',
                       'Third Sentence, test1, test2; test3-test4 .',
                       'GO!',
                       'Trust me,.',
                       'Sentence with dots.',
                       'Word, Word, word,.']
    assert tokens == expected_tokens, print('tokens:', tokens)


def test_flair_clause_extractor_close_type_none(example_sentence, chunk_english_tagger):
    """clause_type=None keeps every chunk found by the Flair tagger."""
    clause_extractor = FlairClauseExtractor(
        tagger=chunk_english_tagger, clause_type=None)
    tokens, separator = clause_extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    expected_tokens = ['One two three',
                       'Second sentence.Third Sentence',
                       'test1',
                       'test2',
                       'test3-test4',
                       'GO',
                       'Trust',
                       'me',
                       'sentence',
                       'not starting',
                       'with',
                       'capital letter',
                       'Sentence',
                       'with',
                       'dots.',
                       'Word',
                       'Word, word,...word',
                       'a sentence']
    assert tokens == expected_tokens


def test_flair_clause_extractor_close_type_NP(example_sentence, chunk_english_tagger):
    """clause_type=['NP'] only keeps noun phrases."""
    clause_extractor = FlairClauseExtractor(
        tagger=chunk_english_tagger, clause_type=['NP'])
    tokens, separator = clause_extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    expected_tokens = ['One two three',
                       'Second sentence.Third Sentence',
                       'test1',
                       'test2',
                       'test3-test4',
                       'me',
                       'sentence',
                       'capital letter',
                       'Sentence',
                       'dots.',
                       'Word',
                       'Word, word,...word',
                       'a sentence']
    assert tokens == expected_tokens


def test_clause_extractor_close_type_ADJP(example_sentence, chunk_english_tagger):
    """No adjective phrase in the example sentence: result must be empty."""
    clause_extractor = FlairClauseExtractor(
        tagger=chunk_english_tagger, clause_type=['ADJP'])
    tokens, separator = clause_extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    print(tokens)
    expected_tokens = []
    assert tokens == expected_tokens


def test_spacy_clause_extractor_close_type_none(example_sentence, web_sm_pipeline):
    """clause_type=None keeps every token produced by the spacy pipeline."""
    clause_extractor = SpacyClauseExtractor(
        pipeline=web_sm_pipeline, clause_type=None)
    tokens, separator = clause_extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    expected_tokens = ['One', 'two', 'three', '.', 'Second', 'sentence', '.',
                       'Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
                       'test3', '-', 'test4', '.GO', '!', 'Trust', 'me', ',',
                       '.', 'sentence', 'not', 'starting', 'with', 'capital',
                       'letter', '.', 'Sentence', 'with', 'dots', '..', 'Word',
                       ',', 'Word', ',', 'word,', '...', 'word', '....', 'so', 'a', 'sentence']
    assert tokens == expected_tokens


def test_spacy_clause_extractor_close_type_NN(example_sentence, web_sm_pipeline):
    """clause_type=['NN'] only keeps tokens tagged as singular nouns."""
    clause_extractor = SpacyClauseExtractor(
        pipeline=web_sm_pipeline, clause_type=['NN'])
    tokens, separator = clause_extractor.extract_tokens(example_sentence)
    assert isinstance(tokens, list)
    assert isinstance(separator, str)
    expected_tokens = ['sentence', 'test1', 'test3', 'test4', 'sentence',
                       'capital', 'letter', 'Sentence', '...', 'word', 'sentence']

    assert tokens == expected_tokens


def test_extractor_factory():
    """Each extract_fct value maps to its extractor; unknown values raise."""
    word_extractor = ExtractorFactory.get_extractor(extract_fct="word")
    assert isinstance(word_extractor, WordExtractor)

    sentence_extractor = ExtractorFactory.get_extractor(extract_fct="sentence")
    assert isinstance(sentence_extractor, SentenceExtractor)

    excerpt_extractor = ExtractorFactory.get_extractor(extract_fct="excerpt")
    assert isinstance(excerpt_extractor, ExcerptExtractor)

    flair_clause_extractor = ExtractorFactory.get_extractor(
        extract_fct="flair_clause", clause_type=['NP'], tagger=None)
    assert isinstance(flair_clause_extractor, FlairClauseExtractor)

    spacy_clause_extractor = ExtractorFactory.get_extractor(
        extract_fct="spacy_clause", clause_type=['NP'], pipeline=None)
    assert isinstance(spacy_clause_extractor, SpacyClauseExtractor)

    with pytest.raises(ValueError):
        ExtractorFactory.get_extractor(extract_fct="invalid")
"""
Module related to Occlusion sensitivity method for NLP.
"""

import numpy as np

from .base import BlackBoxExplainer
from ..commons import Tasks
from ..types import Callable, Union, Optional, OperatorSignature, List


class NlpOcclusion(BlackBoxExplainer):
    """
    Occlusion explainer for NLP.

    Word importance is estimated by removing one word at a time from the
    sentence and measuring the impact of each removal on the model output.

    Parameters
    ----------
    model
        Callable mapping a batch of sentences to an array of scores of
        shape (batch_size, nb_concepts).
    batch_size
        Number of inputs processed at once, by default 32.
    operator
        Operator used to compute the model score, optional.
    """

    def __init__(self,
                 model: Callable,
                 batch_size: Optional[int] = 32,
                 operator: Optional[Union[Tasks, str, OperatorSignature]] = None):
        super().__init__(model, batch_size, operator)

    @staticmethod
    def _get_masks(input_len: int) -> np.ndarray:
        """
        Generate occlusion masks for a given input length.

        Parameters
        ----------
        input_len
            The length of the input for which occlusion masks are generated.
            Typically it will be the number of words of a sentence.

        Returns
        -------
        occlusion_masks
            Boolean array of shape (input_len, input_len): the complement of
            the identity matrix (False on the main diagonal, True elsewhere).
            Row i keeps every word except the i-th, so the masks can be used
            to generate n sentences, each with a single distinct word removed.
        """
        return np.eye(input_len) == 0

    @staticmethod
    def _apply_masks(words: List[str], masks: np.ndarray) -> List[List[str]]:
        """
        Apply occlusion masks to a list of words.

        Parameters
        ----------
        words
            The list of words to which occlusion masks are applied.
        masks
            The boolean occlusion masks, one row per occluded sentence.

        Returns
        -------
        occluded_words
            One list of words per mask, keeping only the words whose mask
            entry is True.
        """
        words_array = np.array(words)
        return [words_array[mask].tolist() for mask in masks]

    def explain(self,
                sentence: str,
                words: List[str],
                separator: str) -> np.ndarray:
        """
        Generate an explanation for the input sentence, by providing the importance of each word.
        The importance is computed by successively occluding each word of the sentence and
        studying the impact of this occlusion on the model results.

        Parameters
        ----------
        sentence
            The input sentence for which an explanation is generated.
        words
            List of words used to generate the explanation. These words must be part of
            the input sentence, the importance will be computed on this list of words
            (i.e. some words of the original sentence can be omitted this way).
        separator
            The separator used to join the words after the occlusion step, so a full
            sentence can be fed to the model.

        Returns
        -------
        explanation
            The generated explanation of shape (nb_concepts, nb_words),
            normalized by the maximum absolute importance.
        """
        # generate n sentences, each with a different word masked (removed)
        masks = NlpOcclusion._get_masks(len(words))
        occluded_words = NlpOcclusion._apply_masks(words, masks)

        perturbed_sentences = [sentence]
        perturbed_sentences.extend(
            separator.join(word_list) for word_list in occluded_words)

        # transform the perturbed sentences into their concept representation
        # u_values has shape: ((W+1) x C)
        u_values = self.model(perturbed_sentences)

        # sensitivity = u_value of the whole sentence - u_value without the word
        whole_sentence_uvalues = u_values[0, :]
        words_uvalues = u_values[1:, :]
        l_importances = (whole_sentence_uvalues - words_uvalues).transpose()
        # normalize; the epsilon avoids a division by zero when every
        # importance is exactly zero
        l_importances /= (np.max(np.abs(l_importances)) + 1e-5)

        return l_importances
"""
NLP Util Classes
"""

from abc import ABC, abstractmethod
import re
from typing import Type

from nltk.tokenize import word_tokenize, sent_tokenize

from flair.models import SequenceTagger
from flair.data import Sentence

from spacy import Language

from ..types import Union, List, Tuple


class TokenExtractor(ABC):
    """
    Base class for the token extractors.

    Parameters
    ----------
    language
        The language of the sentence.
    """

    def __init__(self, language: str = "english"):
        self.separator = " "
        self.language = language

    @abstractmethod
    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
        """
        Extract tokens from a sentence.

        Parameters
        ----------
        sentence
            The input sentence, or a list of input sentences.

        Returns
        -------
        tokens
            A list of extracted tokens.
        separator
            The separator that can be used to join the tokens back into a
            sentence.
        """


class WordExtractor(TokenExtractor):
    """
    Uses NLTK word tokenizer.
    """

    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
        """
        Extract tokens from a sentence, using the NLTK word tokenizer.

        Parameters
        ----------
        sentence
            The input sentence, or a list of input sentences.

        Returns
        -------
        tokens
            A list of extracted words.
        separator
            The separator used to join the words back together (a space).
        """
        # a list of sentences is merged into a single string first
        if not isinstance(sentence, str):
            sentence = '.'.join(sentence)
        words = word_tokenize(sentence)

        return words, self.separator


class SentenceExtractor(TokenExtractor):
    """
    Uses NLTK sentence tokenizer.

    Parameters
    ----------
    language
        The language of the sentence.
    """

    def __init__(self, language: str = "english"):
        super().__init__(language=language)
        self.separator = ". "

    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
        """
        Extract tokens from a sentence, using the NLTK sentence tokenizer.

        Parameters
        ----------
        sentence
            The input sentence, or a list of input sentences.

        Returns
        -------
        tokens
            A list of extracted sentences.
        separator
            The separator used to join the sentences back together ('. ').
        """
        if not isinstance(sentence, str):
            sentence = '.'.join(sentence)
        words = sent_tokenize(sentence, self.language)
        return words, self.separator


class ExcerptExtractor(TokenExtractor):
    """
    Uses a custom excerpt tokenizer.
    """

    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
        """
        Extract excerpts from a sentence:
        - split the input string with '.', '?', and '!' separators
        - ignore excerpts not starting with a capital letter

        Parameters
        ----------
        sentence
            The input sentence, or a list of input sentences.

        Returns
        -------
        tokens
            A list of extracted excerpts.
        separator
            The separator used to join the excerpts back together.
        """
        if not isinstance(sentence, str):
            sentence = '.'.join(sentence)
        # An excerpt starts with a capital letter and runs up to the first
        # '.', '?' or '!' (included).
        regexp = r"[A-Z][^.?!]*[.?!]"
        res = re.findall(regexp, sentence)
        excerpt_dataset = [st.strip() for st in res]
        return excerpt_dataset, self.separator


class FlairClauseExtractor(TokenExtractor):
    """
    Uses Flair sentence tokenizer.

    Parameters
    ----------
    tagger
        A trained Flair SequenceTagger.
    clause_type
        A list with the types of clauses to keep. Each clause shall be a string
        corresponding to one of Flair SequenceTagger dictionary.
        Example: clause_type = ['NP', 'ADJP']
        If None, all clauses are kept.
    language
        The language of the sentence.
    """

    def __init__(self,
                 tagger: Type[SequenceTagger],
                 clause_type: List[str] = None,
                 language: str = "english"):
        super().__init__(language=language)
        self.clause_type = clause_type
        self.tagger = tagger

    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
        """
        Separates the input texts into clauses, and only keeps the ones belonging to
        the specified types.
        If clause_type is None, the texts are split but all the clauses are kept.

        Parameters
        ----------
        sentence
            A list of strings that we wish to separate into clauses.

        Returns
        -------
        clause_list
            A list with input texts split into clauses.
        separator
            The separator used for splitting the sentence.
        """
        if not isinstance(sentence, str):
            sentence = '.'.join(sentence)
        sentence = Sentence(sentence)
        self.tagger.predict(sentences=sentence)
        clause_list = []
        for segment in sentence.get_labels():
            # keep every chunk when clause_type is None, otherwise filter on
            # the predicted chunk type
            if self.clause_type is None:
                clause_list.append(segment.data_point.text)
            elif segment.value in self.clause_type:
                clause_list.append(segment.data_point.text)

        return clause_list, self.separator


class SpacyClauseExtractor(TokenExtractor):
    """
    Uses Spacy sentence tokenizer.

    Parameters
    ----------
    pipeline
        A trained Spacy pipeline.
    clause_type
        A list with the types of clauses to keep. Each clause shall be a string
        corresponding to one of Spacy Tag dictionary.
        Example: clause_type = ['NNP', 'VBZ']
        If None, all clauses are kept.
    """

    def __init__(self,
                 pipeline: Type[Language],
                 clause_type: List[str] = None):
        super().__init__(language=None)
        self.clause_type = clause_type
        self.pipeline = pipeline

    def extract_tokens(self, sentence: Union[List[str], str]) -> Tuple[List[str], str]:
        """
        Separates the input texts into clauses, and only keeps the ones belonging to
        the specified types.
        If clause_type is None, the texts are split but all the clauses are kept.

        Parameters
        ----------
        sentence
            A list of strings that we wish to separate into clauses.

        Returns
        -------
        clause_list
            A list with input texts split into clauses.
        separator
            The separator used for splitting the sentence.
        """
        if not isinstance(sentence, str):
            sentence = '.'.join(sentence)

        clause_list = []
        for token in self.pipeline(sentence):
            # keep every token when clause_type is None, otherwise filter on
            # the fine-grained POS tag
            if self.clause_type is None:
                clause_list.append(token.text)
            elif token.tag_ in self.clause_type:
                clause_list.append(token.text)

        return clause_list, self.separator


class ExtractorFactory():
    """
    Factory for extractor classes.
    """
    @staticmethod
    def get_extractor(extract_fct="sentence", tagger=None, pipeline=None, clause_type=None):
        """
        Get an instance of an extractor based on the specified extraction function.

        Parameters
        ----------
        extract_fct
            The type of extraction function to use, either 'word', 'sentence', 'excerpt',
            'flair_clause', 'spacy_clause'.
            Default is "sentence".
        tagger
            A Flair SequenceTagger to use for Flair clause extractor.
        pipeline
            A Spacy Pipeline to use for Spacy clause extractor.
        clause_type
            Additional parameter specifying the type of clause (default is None).

        Returns
        -------
        Extractor
            An instance of the selected extractor based on the specified function.

        Raises
        ------
        ValueError
            If `extract_fct` is not one of the supported values.
        """
        if extract_fct == "flair_clause":
            return FlairClauseExtractor(tagger, clause_type)
        if extract_fct == "spacy_clause":
            return SpacyClauseExtractor(pipeline, clause_type)
        if extract_fct == "sentence":
            return SentenceExtractor()
        if extract_fct == "word":
            return WordExtractor()
        if extract_fct == "excerpt":
            return ExcerptExtractor()
        # list the values actually accepted above (the previous message
        # mentioned a non-existent 'clause' option)
        raise ValueError("Extraction function can only be 'word', 'sentence', "
                         "'excerpt', 'flair_clause' or 'spacy_clause'")
b/xplique/commons/torch_operations.py new file mode 100644 index 00000000..2dd0c431 --- /dev/null +++ b/xplique/commons/torch_operations.py @@ -0,0 +1,150 @@ +""" +Custom pytorch operations +""" + +from abc import ABC, abstractmethod +from math import ceil +import numpy as np +import torch + +from ..types import Union, Tuple, Callable, List, Dict + + +class NlpPreprocessor(ABC): + """ + Abstract base class for NLP preprocessing. + + Parameters + ---------- + tokenizer + The tokenizer function to be used for tokenization. It must be + a callable that returns outputs matching the model inputs expectations. + device + The device on which to perform tokenization (default is 'cuda'). + kwargs + A list of key value arguments that will be passed to the tokenizer when + tokenizing inputs. + """ + + def __init__(self, tokenizer: Callable, device='cuda', **kwargs): + self.tokenizer = tokenizer + self.device = device + self.tokenizer_kwargs = kwargs + + def tokenize(self, + samples: Union[List[str], np.ndarray]) -> Dict[str, torch.Tensor]: + """ + Transforms a list of strings into tokens for consumption by the transformer model. + + Parameters + ---------- + samples + A list of input strings to be tokenized. + + Returns + ------- + tokenized + Result of the tokenizer operation over the input sequences. Usually a dictionary + with keys 'input_ids' and 'attention_mask'. + """ + tokenized = self.tokenizer( + list(samples), + **self.tokenizer_kwargs + ).to(self.device) + + return tokenized + + @abstractmethod + def preprocess(self, + inputs: List[str], + labels: List[str]) -> Tuple[np.ndarray, np.ndarray]: + """ + Pre-process the dataset and adapt it according to a specific task. + + Parameters + ---------- + inputs + The input text data. + labels + The corresponding labels. + + Returns: + ---------- + preprocessed_inputs + The pre-processed inputs. + preprocessed_labels + The pre-processed outputs. 
+ """ + raise NotImplementedError + + +def batcher(elements, batch_size: int): + """ + An function to create batches from a list of elements. + + Parameters + ---------- + elements + The list of elements to be batched. + batch_size + The size of each batch. + + Returns + ------ + batch + A batch of elements (yielded). + """ + nb_batchs = ceil(len(elements) / batch_size) + + for batch_i in range(nb_batchs): + batch_start = batch_i * batch_size + batch_end = batch_start + batch_size + + batch = elements[batch_start:batch_end] + yield batch + + +def nlp_batch_predict(model: Union[torch.nn.Module, Callable], + preprocessor: NlpPreprocessor, + inputs: List[str], + labels: List[str], + batch_size: int = 64) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pre-processes and predicts using the transformer model in batches. + + Parameters + ---------- + model + The transformer model for prediction. + preprocessor + An instance of NlpPreprocessor used for preprocessing the input texts. + inputs + A list of n_samples input texts to be predicted. + labels + A list of labels corresponding to the input texts. + batch_size + The batch size (default is 64). + + Returns + ------- + predictions + The predictions output of the model for the pre-processed input texts. + processed_labels + The pre-processed labels. 
+ """ + predictions = None + processed_labels = None + + with torch.no_grad(): + dataset = list(zip(inputs, labels)) + for batch in batcher(dataset, batch_size): + batch_inputs, batch_labels = zip(*batch) + x_preprocessed, y_preprocessed = preprocessor.preprocess( + np.array(batch_inputs), np.array(batch_labels)) + out_batch = model(**x_preprocessed) + predictions = out_batch if predictions is None else torch.cat( + [predictions, out_batch]) + processed_labels = y_preprocessed if processed_labels is None else torch.cat([ + processed_labels, y_preprocessed]) + + return predictions, processed_labels diff --git a/xplique/concepts/__init__.py b/xplique/concepts/__init__.py index 67a100a6..46d7b23a 100644 --- a/xplique/concepts/__init__.py +++ b/xplique/concepts/__init__.py @@ -8,5 +8,7 @@ from .craft_tf import CraftTf, CraftManagerTf try: from .craft_torch import CraftTorch, CraftManagerTorch + from .cockatiel import CockatielTorch + from .cockatiel_manager import CockatielManagerTorch except ImportError: pass diff --git a/xplique/concepts/cockatiel.py b/xplique/concepts/cockatiel.py new file mode 100644 index 00000000..598ab6fe --- /dev/null +++ b/xplique/concepts/cockatiel.py @@ -0,0 +1,569 @@ +""" +COCKATIEL Module for Pytorch +""" + +from typing import Callable, List, Tuple, Optional, Union, Type +from math import ceil + +import html +from IPython.display import display, HTML + +import numpy as np +import torch + +from xplique.attributions.base import BlackBoxExplainer +from ..commons import TokenExtractor, WordExtractor, ExcerptExtractor, NlpPreprocessor +from .craft import MaskSampler +from .craft_torch import BaseCraftTorch + + +class BaseCockatiel(BaseCraftTorch): + """ + A class implementing COCKATIEL, the concept based explainability method for NLP introduced + in https://arxiv.org/abs/2305.06754 + + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. 
        Must be a Pytorch model accepting tokenized inputs.
    latent_to_logit_model
        The second part of the model taking activation and returning
        logits, h(.) in the original paper.
        Must be a Pytorch model.
    preprocessor
        A callable object to transform strings into inputs for the model.
        Embeds a tokenizer that will be used to feed the inputs to the model.
    number_of_concepts
        The number of concepts to extract. Default is 25.
    batch_size
        The batch size for all the operations that use the model. Default is 256.
    device
        The type of device on which to place the torch tensors
    """

    def __init__(
            self,
            input_to_latent_model: Callable,
            latent_to_logit_model: Callable,
            preprocessor: NlpPreprocessor,
            number_of_concepts: int = 25,
            batch_size: int = 256,
            device: str = 'cuda'
    ):
        super().__init__(input_to_latent_model, latent_to_logit_model,
                         number_of_concepts, batch_size, device=device)
        self.preprocessor = preprocessor
        # For NLP, "patches" are text excerpts (sentence-like fragments)
        # rather than image crops.
        self.patch_extractor = ExcerptExtractor()

    def _latent_predict(self, inputs: List[str], resize=None) -> torch.Tensor:
        """
        Compute the embedding space using the 1st model `input_to_latent_model`.

        Parameters
        ----------
        inputs
            A list of input string data.
        resize
            Unused parameter. Kept to stay signature-compatible with the
            image-based BaseCraftTorch implementation.

        Returns
        -------
        activations
            The latent activations of shape (n_samples, channels)
        """
        def gen_activations():
            # Tokenize and run the encoder batch by batch to bound memory usage.
            with torch.no_grad():
                nb_batchs = ceil(len(inputs) / self.batch_size)
                for batch_id in range(nb_batchs):
                    batch_start = batch_id * self.batch_size
                    batch_end = batch_start + self.batch_size
                    batch_tokenized = self.preprocessor.tokenize(
                        samples=inputs[batch_start:batch_end])

                    batch_activations = self.input_to_latent_model(
                        **batch_tokenized)
                    yield batch_activations

        activations = torch.cat(list(gen_activations()), 0)
        return activations

    def _preprocess(self, activations: torch.Tensor) -> np.ndarray:
        """
        Preprocesses the activations to make sure that they're the right shape
        for being input to the NMF algorithm later.

        Parameters
        ----------
        activations
            The (non-negative) activations from the model under study.

        Returns
        -------
        activations
            The preprocessed activations, ready for COCKATIEL.
        """

        # pylint disable=no-member
        # NMF factorizes a non-negative matrix, hence the positivity check.
        assert torch.min(activations) >= 0.0, "Activations must be positive."

        # 4-D activations are pooled over the two middle axes (global average
        # pooling). NOTE(review): assumes a (n, h, w, c)-like layout -- confirm
        # against the encoder's output.
        if len(activations.shape) == 4:
            activations = torch.mean(activations, dim=(1, 2))

        return self._to_np_array(activations)

    def _extract_patches(self, inputs: List[str]) -> Tuple[List[str], np.ndarray]:
        """
        Extract patches (excerpts) from the input sentences, and compute their embeddings.

        Parameters
        ----------
        inputs
            Input sentences.

        Returns
        -------
        patches
            A list of excerpts (n_patches).
        activations
            The excerpts activations (n_patches, channels).
        """
        crops, _ = self.patch_extractor.extract_tokens(inputs)

        activations = self._latent_predict(crops)

        # applying GAP(.) on the activation and ensure positivity if needed
        activations = self._preprocess(activations)

        return crops, activations

    def estimate_importance(self,
                            inputs: np.ndarray = None,
                            sampler: MaskSampler = MaskSampler.SOBOL,
                            nb_design: int = 32,
                            cmaps: Optional[Union[Tuple, str]] = None) -> np.ndarray:
        """
        Estimates the importance of each concept for a given class, either globally
        on the whole dataset provided in the fit() method (in this case, inputs shall
        be set to None), or locally on a specific input sentence.

        Parameters
        ----------
        inputs : numpy array or Tensor
            The input data on which to compute the importances.
            If None, then the inputs provided in the fit() method
            will be used (global importance of the whole dataset).
            Default is None.
        sampler
            The sampling method to use for masking. Defaults to MaskSampler.SOBOL.
        nb_design
            The number of design to use for the importance estimation. Default is 32.
        cmaps
            The list of colors associated with each concept.
            Can be either:
            - A list of (r, g, b) colors to use as a base for the colormap.
            - A colormap name compatible with `plt.get_cmap(cmap)`.

        Returns
        -------
        importances
            The Sobol total index (importance score) for each concept.

        """
        # Pure delegation to BaseCraftTorch; overridden only to carry an
        # NLP-specific docstring.
        return super().estimate_importance(inputs=inputs, sampler=sampler,
                                           nb_design=nb_design, cmaps=cmaps)

    def get_best_excerpts_per_concept(self,
                                      nb_excerpts: int = 10,
                                      nb_most_important_concepts: int = None) -> List[str]:
        """
        Return the best excerpts for each concept.

        Parameters
        ----------
        nb_excerpts
            The number of excerpts (patches) to fetch per concept. Defaults to 10.
        nb_most_important_concepts
            The number of concepts to display. If provided, only display
            nb_most_important_concepts, otherwise display them all.
            Default is None.
+ + Returns + ------- + best_excerpts_per_concept + The list of the best excerpts per concept + """ + best_excerpts_per_concept = [] + for _, _, best_crops in \ + self._gen_best_concepts_crops(nb_excerpts, nb_most_important_concepts): + best_excerpts_per_concept.append(best_crops) + return best_excerpts_per_concept + + +class CockatielNlpVisualizationMixin(): + """ + Class containing text visualization methods for Cockatiel. + """ + + def display_concepts_excerpts(self, + nb_patchs: int = 10, + nb_most_important_concepts: int = None) -> None: + """ + Display the best excerpts for each concept. + + Parameters + ---------- + nb_patchs + The number of patches to display per concept. Defaults to 10. + nb_most_important_concepts + The number of concepts to display. If provided, only display + nb_most_important_concepts, otherwise display them all. + Default is None. + """ + for c_id, c_id_importance, best_crops in \ + self._gen_best_concepts_crops(nb_patchs, nb_most_important_concepts): + + print(f"Concept {c_id} has an importance value of " + f"{c_id_importance:.2f}") + for crop in best_crops: + print(f"\t{crop}") + + def plot_concept_attribution_maps(self, + sentences: List[str], + token_extractor: TokenExtractor, + explainer_class: Type[BlackBoxExplainer], + ignore_words: List[str] = None, + importances: np.ndarray = None, + nb_most_important_concepts: int = 5, + title: str = "", + css_namespace: str = "", + filter_percentile: int = 80, + display_width: int = 400) -> List[str]: + """ + Display the concepts attribution maps for the sentences given in argument. + + Parameters + ---------- + sentences + The list of sentences for which attribution maps will be displayed. + token_extractor + The token extractor used to extract tokens from the sentences. + explainer_class + Explainer class used during the masking process. Typically a NlpOcclusion class. + ignore_words + Words to ignore during the occlusion process. These will not be part of the + extracted concepts. 
+ importances + The importances computed by the estimate_importance() method. + If None is provided, then the global importances will be used, otherwise + the local importances set in this parameter will be used. + nb_most_important_concepts + Number of most important concepts to consider. Default is 5. + title + The title to use when displaying the results. + css_namespace + A namespace to be prefixed to the css class in order to avoid collisions. + Usefull when running several instances of Cockatiel, to prevent having the + same color for all the concepts of the same name. + filter_percentile + Percentile used to filter the concept heatmap when displaying the HTML + sentences (only show concept if excess N-th percentile). Default is 80. + display_width + Width of the displayed HTML text. Default is 400. + + Returns + ------- + html_lines + List of HTML text lines. + """ + if importances is None: + # global + most_important_concepts = self.sensitivity.most_important_concepts + else: + # local + most_important_concepts = np.argsort(importances)[::-1] + + most_important_concepts = most_important_concepts[:nb_most_important_concepts] + + # Build the CSS + concepts_css_names = [ + f'{css_namespace}_{concept_id}' for concept_id in most_important_concepts] + + css = CockatielNlpVisualizationMixin.build_concepts_css( + concepts_names=concepts_css_names, + cmaps=self.sensitivity.cmaps[:nb_most_important_concepts]) + html_output = [css] + + # Build the title + html_output.append(f"{title}") + + # Build the legend + html_output.append( + f"
Legend :
" + f"{CockatielNlpVisualizationMixin.get_legend(concepts_css_names)}
") + + # Generate the HTML for each sentence + html_output.append('') # list end + + # Display the HTML text & return it as well + display(HTML("".join(html_output))) + return html_output + + def sentence_attribution_map(self, + sentence: str, + token_extractor: WordExtractor, + explainer_class: Type[BlackBoxExplainer], + most_important_concepts: np.ndarray, + concepts_names: List[str] = None, + ignore_words: List[str] = None, + filter_percentile: int = 80, + display_width: int = 400) -> str: + """ + Compute the concepts attribution maps for the sentence given in argument, + and return the corresponding sentence as an HTML formatted sentence where + the words belonging to each concept are highlighted with colors. + + Parameters + ---------- + sentence + The sentence to process. + token_extractor + The token extractor used to extract tokens from the sentences. + explainer_class + Explainer class used during the masking process. Typically a NlpOcclusion class. + most_important_concepts + The concepts ids to display. + concepts_names + The concepts names to use for the HTML display, corresponding to + most_important_concepts. + ignore_words + Words to ignore during the occlusion process. These will not be part of the + extracted concepts. + filter_percentile + Percentile used to filter the concept heatmap when displaying the HTML + sentences (only show concept if excess N-th percentile). Default is 80. + display_width + Width of the displayed HTML text. Default is 400. + + Returns + ------- + html_lines + The sentence formatted as HTML text. 
+ """ + extracted_words, separator = token_extractor.extract_tokens(sentence) + + # Filter words + if ignore_words is not None: + words = [word for word in extracted_words if word not in ignore_words] + else: + words = extracted_words + + explainer = explainer_class(model=self.transform) + l_importances = explainer.explain( + sentence=sentence, words=words, separator=separator) + + # Display only the most important concepts + l_importances = l_importances[most_important_concepts] + + return self.convert_sentence_to_html(extracted_words, separator, words, l_importances, + concepts_names=concepts_names, + filter_percentile=filter_percentile, + display_width=display_width) + + @staticmethod + def get_legend(concepts_names: List[str]) -> str: + """ + Generates an HTML legend for the concepts attribution maps. + + Parameters + ---------- + concepts_names + The list of concepts names that are handled by the current Cockatiel instance. + + Returns + ------- + html_legend + The legend formatted as HTML text. + """ + + # Extract the label name from the concept name, which are composed of [css_namespace]_[cid] + labels = [ + f"Concept {c_name.split('_')[-1]}" for c_name in concepts_names] + legend_importances = np.eye(len(concepts_names)) / 2.0 + return CockatielNlpVisualizationMixin.\ + convert_sentence_to_html(extracted_words=labels, + words=labels, + separator=" ", + concepts_names=concepts_names, + explanation=legend_importances) + + @staticmethod + def convert_sentence_to_html( + extracted_words: List[str], + separator: str, + words: List[str], + explanation: np.ndarray, + concepts_names: List[str], + filter_percentile: int = 80, + display_width: int = 400 + ) -> str: + """ + Generates the visualization for COCKATIEL's explanations. + + Parameters + ---------- + extracted_words + List of extracted words from the input sentence: it shall + be possible to reconstruct the whole input sentence using `extracted_words` + and `separator`. 
+ separator + Separator used to join extracted_words to build the whole sentence. + words + List of words used to generate the explanation, these words should be part + of the sentence and are attached to a concept of the `explanation` parameter below. + explanation + An array that corresponds to the output of the occlusion function. + Has a shape of (nb_concepts x nb_words). + concepts_names + The list of concepts names that are handled by the current Cockatiel instance. + filter_percentile + Percentile used to filter the concept heatmap when displaying the HTML + sentences (only show concept if excess N-th percentile). Default is 80. + display_width + Width of the displayed HTML text. Default is 400. + + Returns + ------- + html_lines + The sentence formatted as HTML text. + """ + l_phi = np.array(explanation) + + # Filter the values that are below the percentile for each concept importance array + sigmas = [[np.percentile(phi, filter_percentile)] for phi in l_phi] + l_phi = l_phi * np.array(l_phi > sigmas, np.float32) + + phi_html = [] + + # Build a dictionary of structure 'word' => importance_value, concept_id + words_importances = dict( + zip(words, zip(l_phi.max(axis=0), l_phi.argmax(axis=0)))) + + i = 0 + for word in extracted_words: + if word in words_importances: + value, value_max_id = words_importances[word] + + # concept name is composed of css_namespace and c_id + concept_name = concepts_names[value_max_id] + c_id = concept_name.split('_')[-1] + + title = f"[{html.escape(word)}] concept:{c_id} importance:{value:.2f}" + + # add debug infos in the popup + # indices = np.nonzero(l_phi[:, i])[0] + # title += "\n DEBUG:" + # for j in indices: + # concept_name = concepts_names[j] + # title += f"\n concept:{concept_name} importance {l_phi[j,i]:.2f}" + + if value > 0: + phi_html.append(f"{word}{separator}") + else: + phi_html.append( + f"{word}{separator}") + + i += 1 + else: + # ignored words + phi_html.append(f" {word}{separator}") + + html_result = f"" \ + 
+ " ".join(phi_html) + "" + return html_result + + @staticmethod + def build_concepts_css(concepts_names: List[str], cmaps: List[str]): + """ + Generates the CSS for HTML COCKATIEL's explanations. + + Parameters + ---------- + concepts_names + The list of concepts names. + cmaps + The color associated with each concept. + Can be either: + - A tuple (r, g, b) of color to use as a base for the colormap. + - A colormap name compatible with `plt.get_cmap(cmap)`. + """ + legend_css = """ + + """ + + concepts_css = [] + for c_id, cmap in zip(concepts_names, cmaps): + red, green, blue, alpha = [int(c*255) for c in cmap(1)] + txt = f"\n.concept{c_id} {{" \ + f" padding: 1px 5px;" \ + f" border: solid 3px;" \ + f" border-color: rgba({red}, {green}, {blue});" \ + f" border-radius: 10px;" \ + f" background-color: rgba({red}, {green}, {blue}, var(--opacity));" \ + f" --opacity: {alpha};" \ + f"}}" + + concepts_css.append(txt) + css_text = "" + + return legend_css + css_text + + +class CockatielTorch(BaseCockatiel, CockatielNlpVisualizationMixin): + """ + Class implementing Cockatiel on Pytorch, adapted for text visualization. + + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must be a Pytorch model accepting tokenized inputs. + latent_to_logit_model + The second part of the model taking activation and returning + logits, h(.) in the original paper. + Must be a Pytorch model. + preprocessor + A callable object to transform strings into inputs for the model. + Embeds a tokenizer that will be used to feed the inputs to the model. + number_of_concepts + The number of concepts to extract. Default is 25. + batch_size + The batch size for all the operations that use the model. Default is 256. 
    device
        The type of device on which to place the torch tensors
    """
diff --git a/xplique/concepts/cockatiel_manager.py b/xplique/concepts/cockatiel_manager.py
new file mode 100644
index 00000000..d2f8d5ac
--- /dev/null
+++ b/xplique/concepts/cockatiel_manager.py
@@ -0,0 +1,256 @@
"""
COCKATIEL MANAGER Module for Pytorch
"""

from typing import Callable, List, Type
import numpy as np
import torch

from xplique.attributions.base import BlackBoxExplainer

from .craft_torch import BaseCraftManagerTorch
from .cockatiel import CockatielTorch
from ..commons import TokenExtractor, NlpPreprocessor
from ..commons.torch_operations import nlp_batch_predict


class BaseCockatielManagerTorch(BaseCraftManagerTorch):
    """
    Base class implementing the CockatielManager on Pytorch.
    This manager creates one Cockatiel instance per class to explain.

    Parameters
    ----------
    input_to_latent_model
        The first part of the model taking an input and returning
        positive activations, g(.) in the original paper.
        Must return positive activations.
    latent_to_logit_model
        The second part of the model taking activation and returning
        logits, h(.) in the original paper.
    preprocessor
        A callable object to transform strings into inputs for the model.
    inputs
        Input data: a list of strings.
    labels
        Labels of the inputs.
    list_of_class_of_interest
        A list of the classes id to explain. The manager will instantiate one
        CraftTorch object per element of this list.
    number_of_concepts
        The number of concepts to extract. Default is 20.
    batch_size
        The batch size to use during training and prediction. Default is 64.
    device
        The type of device on which to place the torch tensors
    """

    def __init__(self, input_to_latent_model: Callable,
                 latent_to_logit_model: Callable,
                 preprocessor: NlpPreprocessor,
                 inputs: List[str],
                 labels: List[str],
                 list_of_class_of_interest: list = None,
                 number_of_concepts: int = 20,
                 batch_size: int = 64,
                 device: str = 'cuda'):

        super().__init__(input_to_latent_model=input_to_latent_model,
                         latent_to_logit_model=latent_to_logit_model,
                         inputs=inputs,
                         labels=labels,
                         list_of_class_of_interest=list_of_class_of_interest)
        self.preprocessor = preprocessor
        self.batch_size = batch_size
        self.device = device
        # One independent Cockatiel explainer per class of interest.
        for class_of_interest in self.list_of_class_of_interest:
            craft = CockatielTorch(input_to_latent_model=input_to_latent_model,
                                   latent_to_logit_model=latent_to_logit_model,
                                   preprocessor=preprocessor,
                                   number_of_concepts=number_of_concepts,
                                   batch_size=batch_size,
                                   device=device)
            self.craft_instances[class_of_interest] = craft

    def compute_predictions(self) -> np.ndarray:
        """
        Computes predictions using the input-to-latent and latent-to-logit models,
        on the input data provided at the creation of the CockatielManager instance.

        Returns
        -------
        np.ndarray
            Predicted labels.
        """
        # Latent features for the whole dataset, computed batch by batch,
        # then mapped to logits and argmax-ed into predicted class ids.
        features, _ = nlp_batch_predict(
            self.input_to_latent_model, preprocessor=self.preprocessor,
            inputs=self.inputs, labels=self.labels, batch_size=self.batch_size)
        logits = self.latent_to_logit_model(features)
        y_preds = np.array(torch.argmax(logits, -1).cpu().detach())
        return y_preds

    def get_best_excerpts_per_concept(self,
                                      class_id: int,
                                      nb_excerpts: int = 10,
                                      nb_most_important_concepts: int = None) -> List[str]:
        """
        Return the best excerpts for each concept.

        Parameters
        ----------
        class_id
            The class to explain.
        nb_excerpts
            The number of excerpts (patches) to fetch per concept. Defaults to 10.
        nb_most_important_concepts
            The number of concepts to display. If provided, only display
            nb_most_important_concepts, otherwise display them all.
            Default is None.

        Returns
        -------
        best_excerpts_per_concept
            The list of the best excerpts per concept
        """
        return self.craft_instances[class_id]\
            .get_best_excerpts_per_concept(nb_excerpts, nb_most_important_concepts)


class CockatielManagerNlpVisualizationMixin:
    """
    Class containing text visualization methods for CockatielManager.
    """

    def plot_concepts_importances(self,
                                  class_id: int,
                                  nb_most_important_concepts: int = 5,
                                  verbose: bool = False):
        """
        Plot a bar chart displaying the importance value of each concept.

        Parameters
        ----------
        class_id
            The class to explain.
        nb_most_important_concepts
            The number of concepts to focus on. Default is 5.
        verbose
            If True, then print the importance value of each concept, otherwise no textual
            output will be printed.
        """
        # importances=None -> the per-class Cockatiel instance falls back to
        # its global importances.
        self.craft_instances[class_id]\
            .plot_concepts_importances(importances=None,
                                       nb_most_important_concepts=nb_most_important_concepts,
                                       verbose=verbose)

    def display_concepts_excerpts(self,
                                  class_id: int,
                                  nb_patchs: int = 10,
                                  nb_most_important_concepts: int = None) -> None:
        """
        Display the best excerpts for each concept.

        Parameters
        ----------
        class_id
            The class to explain.
        nb_patchs
            The number of patches to display per concept. Defaults to 10.
        nb_most_important_concepts
            The number of concepts to display. If provided, only display
            nb_most_important_concepts, otherwise display them all.
            Default is None.
+ """ + self.craft_instances[class_id]\ + .display_concepts_excerpts(nb_patchs=nb_patchs, + nb_most_important_concepts=nb_most_important_concepts) + + def plot_concept_attribution_maps(self, + class_id: int, + sentences: List[str], + token_extractor: TokenExtractor, + explainer_class: Type[BlackBoxExplainer], + ignore_words: List[str] = None, + importances: np.ndarray = None, + nb_most_important_concepts: int = 5, + title: str = "", + filter_percentile: int = 80, + display_width: int = 400) -> List[str]: + """ + Display the concepts attribution maps for the sentences given in argument. + + Parameters + ---------- + class_id + The class id to explain. + sentences + The list of sentences for which attribution maps will be displayed. + token_extractor + The token extractor used to extract tokens from the sentences. + explainer_class + Explainer class used during the masking process. Typically a NlpOcclusion class. + ignore_words + Words to ignore during the occlusion process. These will not be part of the + extracted concepts. + importances + The importances computed by the estimate_importance() method. + If None is provided, then the global importances will be used, otherwise + the local importances set in this parameter will be used. + nb_most_important_concepts + Number of most important concepts to consider. Default is 5. + title + The title to use when displaying the results. + filter_percentile + Percentile used to filter the concept heatmap when displaying the HTML + sentences (only show concept if excess N-th percentile). Default is 80. + display_width + Width of the displayed HTML text. Default is 400. + + Returns + ------- + html_lines + List of HTML text lines. 
+ """ + self.craft_instances[class_id].plot_concept_attribution_maps( + sentences=sentences, + token_extractor=token_extractor, + explainer_class=explainer_class, + ignore_words=ignore_words, + importances=importances, + nb_most_important_concepts=nb_most_important_concepts, + title=title, + filter_percentile=filter_percentile, + display_width=display_width, + css_namespace=class_id) + + +class CockatielManagerTorch(BaseCockatielManagerTorch, CockatielManagerNlpVisualizationMixin): + """ + Class implementing CockatielManager on Pytorch, adapted for text visualization. + This manager creates one CockatielTorch instance per class to explain. + + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must return positive activations. + latent_to_logit_model + The second part of the model taking activation and returning + logits, h(.) in the original paper. + preprocessor + A callable object to transform strings into inputs for the model. + inputs + Input data: a list of strings. + labels + Labels of the inputs of shape (n_samples, class_id) + list_of_class_of_interest + A list of the classes id to explain. The manager will instanciate one + CraftTorch object per element of this list. + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. 
+ device + The type of device on which to place the torch tensors + """ diff --git a/xplique/concepts/craft.py b/xplique/concepts/craft.py index dbfb5809..415ce152 100644 --- a/xplique/concepts/craft.py +++ b/xplique/concepts/craft.py @@ -7,6 +7,7 @@ from enum import Enum import colorsys from math import ceil +import random import numpy as np import cv2 @@ -16,12 +17,14 @@ from matplotlib.colors import ListedColormap, LinearSegmentedColormap from matplotlib import gridspec -from xplique.attributions.global_sensitivity_analysis import (HaltonSequenceRS, JansenEstimator) +from xplique.attributions.global_sensitivity_analysis import \ + (HaltonSequenceRS, ScipySobolSequenceRS, LatinHypercubeRS, JansenEstimator) from xplique.plots.image import _clip_percentile from ..types import Callable, Tuple, Optional, Union from .base import BaseConceptExtractor + @dataclasses.dataclass class Factorization: """ Dataclass handling data produced during the Factorization step.""" @@ -32,6 +35,7 @@ class Factorization: crops_u: np.ndarray concept_bank_w: np.ndarray + class Sensitivity: """ Dataclass handling data produced during the Sobol indices computation. @@ -53,14 +57,14 @@ class Sensitivity: """ def __init__(self, importances: np.ndarray, - most_important_concepts: np.ndarray, - cmaps: Optional[Union[list, str]]=None): + most_important_concepts: np.ndarray, + cmaps: Optional[Union[list, str]] = None): self.importances = importances self.most_important_concepts = most_important_concepts self.set_concept_attribution_cmap(cmaps=cmaps) @staticmethod - def _get_alpha_cmap(cmap: Union[Tuple,str]): + def _get_alpha_cmap(cmap: Union[Tuple, str]): """ Creat a colormap with an alpha channel, out of 3 (r, g, b) values. 
This is used in particular by `set_concept_attribution_cmap()` to @@ -93,12 +97,12 @@ def _get_alpha_cmap(cmap: Union[Tuple,str]): cmap = LinearSegmentedColormap.from_list("", [colors, cmax]) alpha_cmap = cmap(np.arange(256)) - alpha_cmap[:,-1] = np.linspace(0, 0.85, 256) + alpha_cmap[:, -1] = np.linspace(0, 0.85, 256) alpha_cmap = ListedColormap(alpha_cmap) return alpha_cmap - def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]]=None): + def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]] = None): """ Set the colormap used for the concepts displayed in the attribution maps. @@ -112,24 +116,33 @@ def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]]=None): """ if cmaps is None: self.cmaps = [ - Sensitivity._get_alpha_cmap((54, 197, 240)), - Sensitivity._get_alpha_cmap((210, 40, 95)), - Sensitivity._get_alpha_cmap((236, 178, 46)), - Sensitivity._get_alpha_cmap((15, 157, 88)), - Sensitivity._get_alpha_cmap((84, 25, 85)), - Sensitivity._get_alpha_cmap((55, 35, 235)) - ] + Sensitivity._get_alpha_cmap((54, 197, 240)), + Sensitivity._get_alpha_cmap((210, 40, 95)), + Sensitivity._get_alpha_cmap((236, 178, 46)), + Sensitivity._get_alpha_cmap((15, 157, 88)), + Sensitivity._get_alpha_cmap((84, 25, 85)), + Sensitivity._get_alpha_cmap((55, 35, 235)) + ] # Add more colors by default cmaps_more = [Sensitivity._get_alpha_cmap(cmap) - for cmap in plt.get_cmap('tab10').colors] + for cmap in plt.get_cmap('tab10').colors] self.cmaps.extend(cmaps_more) else: self.cmaps = [Sensitivity._get_alpha_cmap(cmap) for cmap in cmaps] if len(self.cmaps) < len(self.most_important_concepts): - raise RuntimeError(f'Not enough colors in cmaps ({len(self.cmaps)}) ' \ - f'compared to the number of important concepts ' \ - '({len(self.most_important_concepts)})') + nb_colors_missing = len( + self.most_important_concepts) - len(self.cmaps) + print(f'Not enough colors in cmaps ({len(self.cmaps)}) ' + f'compared to the number of important 
class MaskSampler(Enum):
    """
    Sampling strategies used to draw the concept masks for the Sobol
    importance estimation (see `BaseCraft.estimate_importance`).

    Each member's value is the sampler class itself (a random-sequence
    generator from `xplique.attributions.global_sensitivity_analysis`),
    so `sampler.value()` instantiates the chosen sampler.
    """
    # Quasi-random low-discrepancy sequences
    HALTON = HaltonSequenceRS
    SOBOL = ScipySobolSequenceRS
    # Stratified random sampling
    LATIN = LatinHypercubeRS

    def __eq__(self, other):
        # Compare on the wrapped sampler class, mirroring the convention
        # already used by DisplayImportancesOrder in this module.
        return self.value == other.value
"latent_to_logit_model must be a callable function" @abstractmethod def _latent_predict(self, inputs: np.ndarray): @@ -216,11 +238,13 @@ def check_if_fitted(self): If the factorization model has not been fitted to input data. """ if self.factorization is None: - raise NotFittedError("The factorization model has not been fitted to input data yet.") + raise NotFittedError( + "The factorization model has not been fitted to input data yet.") def fit(self, - inputs : np.ndarray, - class_id: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + inputs: np.ndarray, + class_id: int = 0, + alpha_w: float = 1e-2) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Fit the Craft model to the input data. @@ -231,6 +255,8 @@ def fit(self, (x1, x2, ..., xn) in the paper. class_id The class id of the inputs. + alpha_w + Constant that multiplies the NMF regularization terms of the concept banck (W in the paper). Returns ------- @@ -246,7 +272,7 @@ def fit(self, crops, activations = self._extract_patches(inputs) # apply NMF to the activations to obtain matrices U and W - reducer = NMF(n_components=self.number_of_concepts, alpha_W=1e-2) + reducer = NMF(n_components=self.number_of_concepts, alpha_W=alpha_w) crops_u = reducer.fit_transform(activations) concept_bank_w = reducer.components_.astype(np.float32) @@ -255,7 +281,7 @@ def fit(self, return crops, crops_u, concept_bank_w - def transform(self, inputs : np.ndarray, activations : np.ndarray = None) -> np.ndarray: + def transform(self, inputs: np.ndarray, activations: np.ndarray = None) -> np.ndarray: """Transforms the inputs data into its concept representation. Parameters @@ -289,10 +315,15 @@ def transform(self, inputs : np.ndarray, activations : np.ndarray = None) -> np. 
if is_4d: # (N * W * H, R) -> (N, W, H, R) with R = nb_concepts - coeffs_u = np.reshape(coeffs_u, (*original_shape, coeffs_u.shape[-1])) + coeffs_u = np.reshape( + coeffs_u, (*original_shape, coeffs_u.shape[-1])) return coeffs_u - def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) -> np.ndarray: + def estimate_importance(self, + inputs: np.ndarray = None, + sampler: MaskSampler = MaskSampler.HALTON, + nb_design: int = 32, + cmaps: Optional[Union[Tuple, str]] = None) -> np.ndarray: """ Estimates the importance of each concept for a given class, either globally on the whole dataset provided in the fit() method (in this case, inputs shall @@ -305,8 +336,15 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) - If None, then the inputs provided in the fit() method will be used (global importance of the whole dataset). Default is None. + sampler + The sampling method to use for masking. Default to MaskSampler.HALTON. nb_design The number of design to use for the importance estimation. Default is 32. + cmaps + The list of colors associated with each concept. + Can be either: + - A list of (r, g, b) colors to use as a base for the colormap. + - A colormap name compatible with `plt.get_cmap(cmap)`. 
Returns ------- @@ -323,7 +361,7 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) - coeffs_u = self.transform(inputs) - masks = HaltonSequenceRS()(self.number_of_concepts, nb_design = nb_design) + masks = sampler.value()(self.number_of_concepts, nb_design=nb_design) estimator = JansenEstimator() importances = [] @@ -348,9 +386,9 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) - for coeff in coeffs_u: u_perturbated = coeff[None, :] * masks[:, None, None, :] a_perturbated = np.reshape(u_perturbated, - (-1, coeff.shape[-1])) @ self.factorization.concept_bank_w + (-1, coeff.shape[-1])) @ self.factorization.concept_bank_w a_perturbated = np.reshape(a_perturbated, - (len(masks), coeffs_u.shape[1], coeffs_u.shape[2], -1)) + (len(masks), coeffs_u.shape[1], coeffs_u.shape[2], -1)) # a_perturbated: (N, H, W, C) y_pred = self._logit_predict(a_perturbated) @@ -365,14 +403,15 @@ def estimate_importance(self, inputs : np.ndarray = None, nb_design: int = 32) - # Save the results of the computation if working on the whole dataset if compute_global_importances: most_important_concepts = np.argsort(importances)[::-1] - self.sensitivity = Sensitivity(importances, most_important_concepts) + self.sensitivity = Sensitivity( + importances, most_important_concepts, cmaps=cmaps) return importances def plot_concepts_importances(self, importances: np.ndarray = None, - display_importance_order: DisplayImportancesOrder = \ - DisplayImportancesOrder.GLOBAL, + display_importance_order: DisplayImportancesOrder = + DisplayImportancesOrder.GLOBAL, nb_most_important_concepts: int = None, verbose: bool = False): """ @@ -422,7 +461,8 @@ def plot_concepts_importances(self, most_important_concepts = most_important_concepts[:nb_most_important_concepts] # Find the correct color index - global_color_index_order = np.argsort(self.sensitivity.importances)[::-1] + global_color_index_order = np.argsort( + self.sensitivity.importances)[::-1] 
local_color_index_order = [np.where(global_color_index_order == local_c)[0][0] for local_c in most_important_concepts] colors = np.array([colors(1.0) @@ -441,7 +481,8 @@ def plot_concepts_importances(self, if verbose: for c_id in most_important_concepts: - print(f"Concept {c_id} has an importance value of {importances[c_id]:.2f}") + print( + f"Concept {c_id} has an importance value of {importances[c_id]:.2f}") @staticmethod def _show(img, **kwargs): @@ -455,6 +496,46 @@ def _show(img, **kwargs): plt.imshow(img, **kwargs) plt.axis('off') + def _gen_best_concepts_crops(self, + nb_crops: int = 10, + nb_most_important_concepts: int = None) \ + -> Tuple[int, float, np.ndarray]: + """ + Generate the best concept crops for each concept. + + Parameters + ---------- + nb_crops : int + The number of crops (patches) to display per concept. Defaults to 10. + nb_most_important_concepts : int + The number of concepts to consider. If provided, only take into account + nb_most_important_concepts, otherwise use them all. + Default is None. + Returns + ------- + Tuple + A tuple containing: + - The current concept id. + - The overall importance score for this concept. + - An array containing the best crops for this concept. + """ + most_important_concepts = self.sensitivity.most_important_concepts + if nb_most_important_concepts is not None: + most_important_concepts = most_important_concepts[:nb_most_important_concepts] + + for c_id in most_important_concepts: + best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[ + ::-1][:nb_crops] + best_crops = np.array(self.factorization.crops)[best_crops_ids] + c_id_importance = self.sensitivity.importances[c_id] + yield c_id, c_id_importance, best_crops + + +class CraftImageVisualizationMixin(): + """ + Class containing image visualization methods for Craft. 
+ """ + def plot_concepts_crops(self, nb_crops: int = 10, nb_most_important_concepts: int = None, @@ -474,17 +555,11 @@ def plot_concepts_crops(self, If True, then print the importance value of each concept, otherwise no textual output will be printed. """ - most_important_concepts = self.sensitivity.most_important_concepts - if nb_most_important_concepts is not None: - most_important_concepts = most_important_concepts[:nb_most_important_concepts] - - for c_id in most_important_concepts: - best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[::-1][:nb_crops] - best_crops = np.array(self.factorization.crops)[best_crops_ids] - + for c_id, c_id_importance, best_crops in \ + self._gen_best_concepts_crops(nb_crops, nb_most_important_concepts): if verbose: - print(f"Concept {c_id} has an importance value of " \ - f"{self.sensitivity.importances[c_id]:.2f}") + print(f"Concept {c_id} has an importance value of " + f"{c_id_importance:.2f}") plt.figure(figsize=(7, (2.5/2)*ceil(nb_crops/5))) for i in range(nb_crops): plt.subplot(ceil(nb_crops/5), 5, i+1) @@ -493,7 +568,7 @@ def plot_concepts_crops(self, def plot_concept_attribution_legend(self, nb_most_important_concepts: int = 6, - border_width: int=5): + border_width: int = 5): """ Plot a legend for the concepts attribution maps. 
@@ -511,13 +586,14 @@ def plot_concept_attribution_legend(self, cmap = self.sensitivity.cmaps[i] plt.subplot(1, len(most_important_concepts), i+1) - best_crops_id = np.argsort(self.factorization.crops_u[:, c_id])[::-1][0] + best_crops_id = np.argsort( + self.factorization.crops_u[:, c_id])[::-1][0] best_crop = self.factorization.crops[best_crops_id] if best_crop.shape[0] > best_crop.shape[-1]: - mask = np.zeros(best_crop.shape[:-1]) # tf + mask = np.zeros(best_crop.shape[:-1]) # tf else: - mask = np.zeros(best_crop.shape[1:]) # torch + mask = np.zeros(best_crop.shape[1:]) # torch mask[:border_width, :] = 1.0 mask[:, :border_width] = 1.0 mask[-border_width:, :] = 1.0 @@ -569,22 +645,23 @@ def compute_subplots_layout_parameters(images: np.ndarray, # get width and height of our images if images.shape[1] == 3: nb_samples = images.shape[0] - [l_width, l_height] = images.shape[2:4] # pytorch + [l_width, l_height] = images.shape[2:4] # pytorch else: - [nb_samples, l_width, l_height] = images.shape[0:3] # tf + [nb_samples, l_width, l_height] = images.shape[0:3] # tf rows = ceil(nb_samples / cols) # define the figure margin, width, height in inch figwidth = cols * img_size + (cols-1) * spacing + 2 * margin - figheight = rows * img_size * l_height/l_width + (rows-1) * spacing + 2 * margin + figheight = rows * img_size * l_height / \ + l_width + (rows-1) * spacing + 2 * margin layout_parameters = { - 'left' : margin/figwidth, - 'right' : 1.-(margin/figwidth), - 'bottom' : margin/figheight, - 'top' : 1.-(margin/figheight), - 'wspace' : spacing/img_size, - 'hspace' : spacing/img_size * l_width/l_height + 'left': margin/figwidth, + 'right': 1.-(margin/figwidth), + 'bottom': margin/figheight, + 'top': 1.-(margin/figheight), + 'wspace': spacing/img_size, + 'hspace': spacing/img_size * l_width/l_height } return layout_parameters, rows, figwidth, figheight @@ -630,7 +707,8 @@ def plot_concept_attribution_maps(self, Additional parameters passed to `plt.imshow()`. 
""" - self.plot_concept_attribution_legend(nb_most_important_concepts=nb_most_important_concepts) + self.plot_concept_attribution_legend( + nb_most_important_concepts=nb_most_important_concepts) # Take into account single vs multiple images array if len(images.shape) == 3: @@ -638,7 +716,8 @@ def plot_concept_attribution_maps(self, # Configure the subplots layout_configuration, rows, figwidth, figheight = \ - self.compute_subplots_layout_parameters(images=images, cols=cols, img_size=img_size) + self.compute_subplots_layout_parameters( + images=images, cols=cols, img_size=img_size) fig = plt.figure() fig.set_size_inches(figwidth, figheight) @@ -666,13 +745,13 @@ def plot_concept_attribution_maps(self, **plot_kwargs) def plot_concept_attribution_map(self, - image: np.ndarray, - most_important_concepts: np.ndarray, - nb_most_important_concepts: int = 5, - filter_percentile: int = 90, - clip_percentile: Optional[float] = 10, - alpha: float = 0.65, - **plot_kwargs): + image: np.ndarray, + most_important_concepts: np.ndarray, + nb_most_important_concepts: int = 5, + filter_percentile: int = 90, + clip_percentile: Optional[float] = 10, + alpha: float = 0.65, + **plot_kwargs): """ Display the concepts attribution map for a single image given in argument. 
@@ -702,15 +781,16 @@ def plot_concept_attribution_map(self, most_important_concepts = most_important_concepts[:nb_most_important_concepts] # Find the colors corresponding to the importances - global_color_index_order = np.argsort(self.sensitivity.importances)[::-1] + global_color_index_order = np.argsort( + self.sensitivity.importances)[::-1] local_color_index_order = [np.where(global_color_index_order == local_c)[0][0] for local_c in most_important_concepts] local_cmap = np.array(self.sensitivity.cmaps)[local_color_index_order] if image.shape[0] == 3: - dsize = image.shape[1:3] # pytorch + dsize = image.shape[1:3] # pytorch else: - dsize = image.shape[0:2] # tf + dsize = image.shape[0:2] # tf BaseCraft._show(image, **plot_kwargs) image_u = self.transform(image)[0] @@ -718,7 +798,8 @@ def plot_concept_attribution_map(self, heatmap = image_u[:, :, c_id] # only show concept if excess N-th percentile - sigma = np.percentile(np.array(heatmap).flatten(), filter_percentile) + sigma = np.percentile( + np.array(heatmap).flatten(), filter_percentile) heatmap = heatmap * np.array(heatmap > sigma, np.float32) # resize the heatmap before cliping @@ -727,12 +808,13 @@ def plot_concept_attribution_map(self, if clip_percentile: heatmap = _clip_percentile(heatmap, clip_percentile) - BaseCraft._show(heatmap, cmap=local_cmap[::-1][i], alpha=alpha, **plot_kwargs) + BaseCraft._show( + heatmap, cmap=local_cmap[::-1][i], alpha=alpha, **plot_kwargs) def plot_image_concepts(self, img: np.ndarray, - display_importance_order: DisplayImportancesOrder = \ - DisplayImportancesOrder.GLOBAL, + display_importance_order: DisplayImportancesOrder = + DisplayImportancesOrder.GLOBAL, nb_most_important_concepts: int = 5, filter_percentile: int = 90, clip_percentile: Optional[float] = 10, @@ -776,7 +858,8 @@ def plot_image_concepts(self, if display_importance_order == DisplayImportancesOrder.LOCAL: # compute the importances for the sample input in argument importances = 
self.estimate_importance(inputs=img) - most_important_concepts = np.argsort(importances)[::-1][:nb_most_important_concepts] + most_important_concepts = np.argsort( + importances)[::-1][:nb_most_important_concepts] else: # use the global importances computed on the whole dataset importances = self.sensitivity.importances @@ -787,7 +870,8 @@ def plot_image_concepts(self, # the crops, and the central part to display the heatmap nb_rows = ceil(len(most_important_concepts) / 2.0) nb_cols = 4 - gs_main = fig.add_gridspec(nb_rows, nb_cols, hspace=0.4, width_ratios=[0.2, 0.4, 0.2, 0.4]) + gs_main = fig.add_gridspec( + nb_rows, nb_cols, hspace=0.4, width_ratios=[0.2, 0.4, 0.2, 0.4]) # Central image # @@ -802,8 +886,8 @@ def plot_image_concepts(self, # Concepts: creation of the axes on left and right of the image for the concepts # - gs_concepts_axes = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 0]) - for i in range(nb_rows)] + gs_concepts_axes = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 0]) + for i in range(nb_rows)] gs_right = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 2]) for i in range(nb_rows)] gs_concepts_axes.extend(gs_right) @@ -812,7 +896,8 @@ def plot_image_concepts(self, nb_crops = 6 # compute the right color to use for the crops - global_color_index_order = np.argsort(self.sensitivity.importances)[::-1] + global_color_index_order = np.argsort( + self.sensitivity.importances)[::-1] local_color_index_order = [np.where(global_color_index_order == local_c)[0][0] for local_c in most_important_concepts] local_cmap = np.array(self.sensitivity.cmaps)[local_color_index_order] @@ -821,22 +906,24 @@ def plot_image_concepts(self, cmap = local_cmap[i] # use a ghost invisible subplot only to have a border around the crops - ghost_axe = fig.add_subplot(gs_concepts_axes[i][:,:]) + ghost_axe = fig.add_subplot(gs_concepts_axes[i][:, :]) ghost_axe.set_title(f"{c_id}", color=cmap(1.0)) ghost_axe.axis('off') - inset_axes 
= ghost_axe.inset_axes([-0.04, -0.04, 1.08, 1.08]) # outer border + inset_axes = ghost_axe.inset_axes( + [-0.04, -0.04, 1.08, 1.08]) # outer border inset_axes.set_xticks([]) inset_axes.set_yticks([]) - for spine in inset_axes.spines.values(): # border color + for spine in inset_axes.spines.values(): # border color spine.set_edgecolor(color=cmap(1.0)) spine.set_linewidth(3) # draw each crop for this concept - gs_current = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec= - gs_concepts_axes[i][:,:]) + gs_current = gridspec.GridSpecFromSubplotSpec( + 2, 3, subplot_spec=gs_concepts_axes[i][:, :]) - best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[::-1][:nb_crops] + best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[ + ::-1][:nb_crops] best_crops = np.array(self.factorization.crops)[best_crops_ids] for i in range(nb_crops): axe = plt.Subplot(fig, gs_current[i // 3, i % 3]) @@ -844,9 +931,10 @@ def plot_image_concepts(self, BaseCraft._show(best_crops[i]) # Right plot: importances - importance_axe = gridspec.GridSpecFromSubplotSpec(3, 2, width_ratios=[0.1, 0.9], - height_ratios=[0.15, 0.6, 0.15], - subplot_spec=gs_main[:, 3]) + importance_axe = gridspec.GridSpecFromSubplotSpec(3, 2, width_ratios=[0.1, 0.9], + height_ratios=[ + 0.15, 0.6, 0.15], + subplot_spec=gs_main[:, 3]) fig.add_subplot(importance_axe[1, 1]) self.plot_concepts_importances(importances=importances, display_importance_order=display_importance_order, diff --git a/xplique/concepts/craft_manager.py b/xplique/concepts/craft_manager.py index fb94bf17..7b286bd5 100644 --- a/xplique/concepts/craft_manager.py +++ b/xplique/concepts/craft_manager.py @@ -6,9 +6,10 @@ import numpy as np -from ..types import Callable, Optional +from ..types import Callable, Optional, Union, Tuple from .craft import DisplayImportancesOrder + class BaseCraftManager(ABC): """ Base class implementing the CRAFT Concept Extraction Mechanism on multiple classes. 
@@ -31,11 +32,11 @@ class BaseCraftManager(ABC): @abstractmethod def __init__(self, - input_to_latent_model : Callable, - latent_to_logit_model : Callable, - inputs : np.ndarray, - labels : np.ndarray, - list_of_class_of_interest : Optional[list] = None): + input_to_latent_model: Callable, + latent_to_logit_model: Callable, + inputs: np.ndarray, + labels: np.ndarray, + list_of_class_of_interest: Optional[list] = None): self.input_to_latent_model = input_to_latent_model self.latent_to_logit_model = latent_to_logit_model self.inputs = inputs @@ -45,7 +46,7 @@ def __init__(self, # take all the classes list_of_class_of_interest = np.array(list(set(labels))) self.list_of_class_of_interest = list_of_class_of_interest - self.craft_instances = None + self.craft_instances = {} @abstractmethod def compute_predictions(self): @@ -61,7 +62,10 @@ def compute_predictions(self): """ raise NotImplementedError - def fit(self, nb_samples_per_class: Optional[int] = None, verbose: bool = False): + def fit(self, + nb_samples_per_class: Optional[int] = None, + alpha_w: float = 1e-2, + verbose: bool = False): """ Fit the Craft models on their respective class of interest. @@ -70,6 +74,8 @@ def fit(self, nb_samples_per_class: Optional[int] = None, verbose: bool = False) nb_samples_per_class Number of samples to use to fit the Craft model. Default is None, which means that all the samples will be used. + alpha_w + Constant that multiplies the NMF regularization terms of `W`. verbose If True, then print the current class CRAFT is fitting, otherwise no textual output will be printed. 
def set_concept_attribution_cmap(self,
                                 class_id: int,
                                 cmaps: Optional[Union[Tuple, str]] = None):
    """
    Set the colormap used for the concepts displayed in the attribution maps,
    for the class `class_id`.

    Parameters
    ----------
    class_id
        The class whose Craft instance colormap shall be updated.
    cmaps
        The list of colors associated with each concept.
        Can be either:
        - A list of (r, g, b) colors to use as a base for the colormap.
        - A colormap name compatible with `plt.get_cmap(cmap)`.
        If None, the default palette is used.
    """
    # Delegate to the per-class Craft instance's Sensitivity object, which
    # owns the colormap configuration.
    sensitivity = self.craft_instances[class_id].sensitivity
    sensitivity.set_concept_attribution_cmap(cmaps)
""" - self.craft_instances[class_id].plot_concepts_importances(importances = None, - nb_most_important_concepts=nb_most_important_concepts, - verbose=verbose) + self.craft_instances[class_id]\ + .plot_concepts_importances(importances=None, + nb_most_important_concepts=nb_most_important_concepts, + verbose=verbose) def plot_concepts_crops(self, class_id: int, nb_crops: int = 10, @@ -143,13 +177,13 @@ def plot_concepts_crops(self, Default is None. """ self.craft_instances[class_id].plot_concepts_crops(nb_crops=nb_crops, - nb_most_important_concepts=nb_most_important_concepts) + nb_most_important_concepts=nb_most_important_concepts) def plot_image_concepts(self, img: np.ndarray, class_id: int, - display_importance_order: DisplayImportancesOrder = \ - DisplayImportancesOrder.GLOBAL, + display_importance_order: DisplayImportancesOrder = + DisplayImportancesOrder.GLOBAL, nb_most_important_concepts: int = 5, filter_percentile: int = 90, clip_percentile: Optional[float] = 10, @@ -188,9 +222,9 @@ def plot_image_concepts(self, Path the file will be saved at. If None, the function will call plt.show(). 
""" self.craft_instances[class_id].plot_image_concepts(img, - display_importance_order=display_importance_order, - nb_most_important_concepts=nb_most_important_concepts, - filter_percentile=filter_percentile, - clip_percentile=clip_percentile, - alpha=alpha, - filepath=filepath) + display_importance_order=display_importance_order, + nb_most_important_concepts=nb_most_important_concepts, + filter_percentile=filter_percentile, + clip_percentile=clip_percentile, + alpha=alpha, + filepath=filepath) diff --git a/xplique/concepts/craft_tf.py b/xplique/concepts/craft_tf.py index 1569cdc9..7a95732b 100644 --- a/xplique/concepts/craft_tf.py +++ b/xplique/concepts/craft_tf.py @@ -7,13 +7,13 @@ import tensorflow as tf import numpy as np -from .craft import BaseCraft -from .craft_manager import BaseCraftManager +from .craft import BaseCraft, CraftImageVisualizationMixin +from .craft_manager import BaseCraftManager, CraftManagerImageVisualizationMixin -class CraftTf(BaseCraft): +class BaseCraftTf(BaseCraft): """ - Class implementing the CRAFT Concept Extraction Mechanism on Tensorflow. + Base class implementing the CRAFT Concept Extraction Mechanism on Tensorflow. Parameters ---------- @@ -33,11 +33,12 @@ class CraftTf(BaseCraft): patch_size The size of the patches to extract from the input data. Default is 64. 
""" - def __init__(self, input_to_latent_model : Callable, - latent_to_logit_model : Callable, - number_of_concepts: int = 20, - batch_size: int = 64, - patch_size: int = 64): + + def __init__(self, input_to_latent_model: Callable, + latent_to_logit_model: Callable, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64): super().__init__(input_to_latent_model, latent_to_logit_model, number_of_concepts, @@ -48,9 +49,9 @@ def __init__(self, input_to_latent_model : Callable, keras_base_layer = tf.keras.Model is_tf_model = issubclass(type(input_to_latent_model), keras_base_layer) & \ - issubclass(type(latent_to_logit_model), keras_base_layer) + issubclass(type(latent_to_logit_model), keras_base_layer) if not is_tf_model: - raise TypeError('input_to_latent_model and latent_to_logit_model are not '\ + raise TypeError('input_to_latent_model and latent_to_logit_model are not ' 'Tensorflow models') def _latent_predict(self, inputs: tf.Tensor): @@ -110,16 +111,19 @@ def _extract_patches(self, inputs: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]: strides = int(self.patch_size * 0.80) patches = tf.image.extract_patches(images=inputs, - sizes=[1, self.patch_size, self.patch_size, 1], + sizes=[1, self.patch_size, + self.patch_size, 1], strides=[1, strides, strides, 1], rates=[1, 1, 1, 1], padding='VALID') - patches = tf.reshape(patches, (-1, self.patch_size, self.patch_size, inputs.shape[-1])) + patches = tf.reshape( + patches, (-1, self.patch_size, self.patch_size, inputs.shape[-1])) # encode the patches and obtain the activations input_width, input_height = inputs.shape[1], inputs.shape[2] activations = self._latent_predict(tf.image.resize(patches, - (input_width, input_height), + (input_width, + input_height), method="bicubic")) assert np.min(activations) >= 0.0, "Activations must be positive." 
@@ -138,9 +142,34 @@ def _to_np_array(self, inputs: tf.Tensor, dtype: type): return np.array(inputs, dtype) -class CraftManagerTf(BaseCraftManager): +class CraftTf(BaseCraftTf, CraftImageVisualizationMixin): """ - Class implementing the CraftManager on Tensorflow. + Class implementing the CRAFT Concept Extraction Mechanism on Tensorflow, + adapted for image processing. + + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must be a Tensorflow model (tf.keras.engine.base_layer.Layer) accepting + data of shape (n_samples, height, width, channels). + latent_to_logit_model + The second part of the model taking activation and returning + logits, h(.) in the original paper. + Must be a Tensorflow model (tf.keras.engine.base_layer.Layer). + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. + patch_size + The size of the patches to extract from the input data. Default is 64. + """ + + +class BaseCraftManagerTf(BaseCraftManager): + """ + Base class implementing the CraftManager on Tensorflow. This manager creates one CraftTf instance per class to explain. Parameters @@ -167,19 +196,19 @@ class CraftManagerTf(BaseCraftManager): patch_size The size of the patches (crops) to extract from the input data. Default is 64. 
""" - def __init__(self, input_to_latent_model : Callable, - latent_to_logit_model : Callable, - inputs : np.ndarray, - labels : np.ndarray, - list_of_class_of_interest : Optional[list] = None, - number_of_concepts: int = 20, - batch_size: int = 64, - patch_size: int = 64): + + def __init__(self, input_to_latent_model: Callable, + latent_to_logit_model: Callable, + inputs: np.ndarray, + labels: np.ndarray, + list_of_class_of_interest: Optional[list] = None, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64): super().__init__(input_to_latent_model, latent_to_logit_model, inputs, labels, list_of_class_of_interest) - self.craft_instances = {} for class_of_interest in self.list_of_class_of_interest: craft = CraftTf(input_to_latent_model, latent_to_logit_model, number_of_concepts, batch_size, patch_size) @@ -196,5 +225,36 @@ def compute_predictions(self): the predictions """ y_preds = np.array(tf.argmax(self.latent_to_logit_model.predict( - self.input_to_latent_model.predict(self.inputs)), 1)) + self.input_to_latent_model.predict(self.inputs)), 1)) return y_preds + + +class CraftManagerTf(BaseCraftManagerTf, CraftManagerImageVisualizationMixin): + """ + Class implementing the CraftManager on Tensorflow, adapted for image processing. + This manager creates one CraftTf instance per class to explain. + + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must return positive activations. + latent_to_logit_model + The second part of the model taking activation and returning + logits, h(.) in the original paper. + inputs + Input data of shape (n_samples, height, width, channels). + (x1, x2, ..., xn) in the paper. + labels + Labels of the inputs of shape (n_samples, class_id) + list_of_class_of_interest + A list of the classes id to explain. The manager will instantiate one + CraftTf object per element of this list. 
+ number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. + patch_size + The size of the patches (crops) to extract from the input data. Default is 64. + """ diff --git a/xplique/concepts/craft_torch.py b/xplique/concepts/craft_torch.py index f429cce4..2b659df5 100644 --- a/xplique/concepts/craft_torch.py +++ b/xplique/concepts/craft_torch.py @@ -3,19 +3,21 @@ """ from typing import Callable, Optional, Tuple +from types import MethodType from math import ceil import torch from torch import nn import numpy as np -from .craft import BaseCraft -from .craft_manager import BaseCraftManager +from .craft import BaseCraft, CraftImageVisualizationMixin +from .craft_manager import BaseCraftManager, CraftManagerImageVisualizationMixin + def _batch_inference(model: torch.nn.Module, dataset: torch.Tensor, batch_size: int = 128, resize: Optional[int] = None, - device: str='cuda') -> torch.Tensor: + device: str = 'cuda') -> torch.Tensor: """ Compute the model predictions of the input images. @@ -59,9 +61,9 @@ def _batch_inference(model: torch.nn.Module, return results -class CraftTorch(BaseCraft): +class BaseCraftTorch(BaseCraft): """ - Class Implementing the CRAFT Concept Extraction Mechanism on Pytorch. + Base class implementing the CRAFT Concept Extraction Mechanism on Pytorch. 
Parameters ---------- @@ -85,22 +87,25 @@ class CraftTorch(BaseCraft): """ def __init__(self, input_to_latent_model: Callable, - latent_to_logit_model: Callable, - number_of_concepts: int = 20, - batch_size: int = 64, - patch_size: int = 64, - device : str = 'cuda'): + latent_to_logit_model: Callable, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64, + device: str = 'cuda'): super().__init__(input_to_latent_model, latent_to_logit_model, number_of_concepts, batch_size) self.patch_size = patch_size self.device = device # Check model type + is_method = isinstance(input_to_latent_model, MethodType) & \ + isinstance(latent_to_logit_model, MethodType) is_torch_model = issubclass(type(input_to_latent_model), torch.nn.modules.module.Module) & \ - issubclass(type(latent_to_logit_model), torch.nn.modules.module.Module) - if not is_torch_model: - raise TypeError('input_to_latent_model and latent_to_logit_model are not ' \ - 'Pytorch modules') + issubclass(type(latent_to_logit_model), + torch.nn.modules.module.Module) + if not (is_method or is_torch_model): + raise TypeError('input_to_latent_model and latent_to_logit_model are not ' + 'Pytorch modules nor methods') def _latent_predict(self, inputs: torch.Tensor, resize=None) -> torch.Tensor: """ @@ -118,7 +123,8 @@ def _latent_predict(self, inputs: torch.Tensor, resize=None) -> torch.Tensor: """ # inputs: (N, C, H, W) if len(inputs.shape) == 3: - inputs = inputs.unsqueeze(0) # add an extra dim in case we get only 1 image to predict + # add an extra dim in case we get only 1 image to predict + inputs = inputs.unsqueeze(0) activations = _batch_inference(self.input_to_latent_model, inputs, self.batch_size, resize, device=self.device) @@ -147,7 +153,8 @@ def _logit_predict(self, activations: np.ndarray, resize=None) -> torch.Tensor: if len(activations_perturbated.shape) == 4: # activations_perturbated: (N, H, W, C) -> (N, C, H, W) - activations_perturbated = activations_perturbated.permute(0, 3, 1, 2) 
+ activations_perturbated = activations_perturbated.permute( + 0, 3, 1, 2) y_pred = _batch_inference(self.latent_to_logit_model, activations_perturbated, self.batch_size, resize, device=self.device) @@ -176,7 +183,8 @@ def _extract_patches(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, np.ndarr # extract patches from the input data, keep patches on cpu strides = int(self.patch_size * 0.80) - patches = torch.nn.functional.unfold(inputs, kernel_size=self.patch_size, stride=strides) + patches = torch.nn.functional.unfold( + inputs, kernel_size=self.patch_size, stride=strides) patches = patches.transpose(1, 2).contiguous().view(-1, num_channels, self.patch_size, self.patch_size) @@ -205,9 +213,36 @@ def _to_np_array(self, inputs: torch.Tensor, dtype: type = None): return res -class CraftManagerTorch(BaseCraftManager): +class CraftTorch(BaseCraftTorch, CraftImageVisualizationMixin): + """ + Class implementing the CRAFT Concept Extraction Mechanism on Pytorch, + adapted for image processing. + + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must be a Pytorch model (torch.nn.modules.module.Module) accepting + data of shape (n_samples, channels, height, width). + latent_to_logit_model + The second part of the model taking activation and returning + logits, h(.) in the original paper. + Must be a Pytorch model (torch.nn.modules.module.Module). + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. + patch_size + The size of the patches (crops) to extract from the input data. Default is 64. + device + The device to use. Default is 'cuda'. + """ + + +class BaseCraftManagerTorch(BaseCraftManager): """ - Class implementing the CraftManager on Tensorflow. + Base class implementing the CraftManager on Pytorch. 
This manager creates one CraftTorch instance per class to explain. Parameters @@ -234,22 +269,22 @@ class CraftManagerTorch(BaseCraftManager): patch_size The size of the patches (crops) to extract from the input data. Default is 64. """ - def __init__(self, input_to_latent_model : Callable, - latent_to_logit_model : Callable, - inputs : np.ndarray, - labels : np.ndarray, - list_of_class_of_interest : list = None, - number_of_concepts: int = 20, - batch_size: int = 64, - patch_size: int = 64, - device : str = 'cuda'): + + def __init__(self, input_to_latent_model: Callable, + latent_to_logit_model: Callable, + inputs: np.ndarray, + labels: np.ndarray, + list_of_class_of_interest: list = None, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64, + device: str = 'cuda'): super().__init__(input_to_latent_model, latent_to_logit_model, inputs, labels, list_of_class_of_interest) self.batch_size = batch_size self.device = device - self.craft_instances = {} for class_of_interest in self.list_of_class_of_interest: craft = CraftTorch(input_to_latent_model, latent_to_logit_model, number_of_concepts, batch_size, patch_size, device) @@ -265,8 +300,41 @@ def compute_predictions(self): y_preds the predictions """ - model = nn.Sequential(self.input_to_latent_model, self.latent_to_logit_model) + model = nn.Sequential(self.input_to_latent_model, + self.latent_to_logit_model) activations = _batch_inference(model, self.inputs, self.batch_size, None, device=self.device) - y_preds = np.array(torch.argmax(activations, -1)) # pylint disable=no-member + y_preds = np.array(torch.argmax(activations, -1) + ) # pylint disable=no-member return y_preds + + +class CraftManagerTorch(BaseCraftManagerTorch, CraftManagerImageVisualizationMixin): + """ + Class implementing the CraftManager on Pytorch, adapted for image processing. + This manager creates one CraftTorch instance per class to explain. 
+ + Parameters + ---------- + input_to_latent_model + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must return positive activations. + latent_to_logit_model + The second part of the model taking activation and returning + logits, h(.) in the original paper. + inputs + Input data of shape (n_samples, height, width, channels). + (x1, x2, ..., xn) in the paper. + labels + Labels of the inputs of shape (n_samples, class_id) + list_of_class_of_interest + A list of the classes id to explain. The manager will instantiate one + CraftTorch object per element of this list. + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. + patch_size + The size of the patches (crops) to extract from the input data. Default is 64. + """ diff --git a/xplique/features_visualizations/preconditioning.py b/xplique/features_visualizations/preconditioning.py index f5758f8d..fc98fcc6 100644 --- a/xplique/features_visualizations/preconditioning.py +++ b/xplique/features_visualizations/preconditioning.py @@ -16,13 +16,6 @@ IMAGENET_SPECTRUM_URL = "https://storage.googleapis.com/serrelab/loupe/"\ "spectrums/imagenet_decorrelated.npy" -imagenet_color_correlation = tf.cast( - [[0.56282854, 0.58447580, 0.58447580], - [0.19482528, 0.00000000,-0.19482528], - [0.04329450,-0.10823626, 0.06494176]], tf.float32 -) - - def recorrelate_colors(images: tf.Tensor) -> tf.Tensor: """ Map uncorrelated colors to 'normal colors' by using empirical color @@ -39,6 +32,11 @@ def recorrelate_colors(images: tf.Tensor) -> tf.Tensor: images Images recorrelated. 
""" + imagenet_color_correlation = tf.cast( + [[0.56282854, 0.58447580, 0.58447580], + [0.19482528, 0.00000000,-0.19482528], + [0.04329450,-0.10823626, 0.06494176]], tf.float32 + ) images_flat = tf.reshape(images, [-1, 3]) images_flat = tf.matmul(images_flat, imagenet_color_correlation) return tf.reshape(images_flat, tf.shape(images)) diff --git a/xplique/utils_functions/object_detection.py b/xplique/utils_functions/object_detection.py index c8bbbe6b..cc1c1016 100644 --- a/xplique/utils_functions/object_detection.py +++ b/xplique/utils_functions/object_detection.py @@ -4,9 +4,6 @@ from typing import Tuple import tensorflow as tf -_EPSILON = tf.constant(1e-4) - - def _box_iou(boxes_a: tf.Tensor, boxes_b: tf.Tensor) -> tf.Tensor: """ Compute the intersection between two batched bounding boxes. @@ -40,6 +37,7 @@ def _box_iou(boxes_a: tf.Tensor, boxes_b: tf.Tensor) -> tf.Tensor: union_area = a_area + b_area - intersection_area + _EPSILON = tf.constant(1e-4) iou_score = intersection_area / (union_area + _EPSILON) return iou_score