Skip to content

Commit 5743025

Browse files
Craft updates to prepare Cockatiel.
This patch separates the methods dedicated to image processing/plotting into different subclasses, so that Cockatiel and NLP related classes will not have to inherit these. Signed-off-by: Frederic Boisnard <frederic.boisnard@irt-saintexupery.com>
1 parent 2ce2816 commit 5743025

4 files changed

Lines changed: 173 additions & 26 deletions

File tree

xplique/concepts/craft.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def set_concept_attribution_cmap(self, cmaps: Optional[Union[Tuple, str]]=None):
129129
if len(self.cmaps) < len(self.most_important_concepts):
130130
raise RuntimeError(f'Not enough colors in cmaps ({len(self.cmaps)}) ' \
131131
f'compared to the number of important concepts ' \
132-
'({len(self.most_important_concepts)})')
132+
f'({len(self.most_important_concepts)})')
133133

134134
class DisplayImportancesOrder(Enum):
135135
"""
@@ -455,6 +455,44 @@ def _show(img, **kwargs):
455455
plt.imshow(img, **kwargs)
456456
plt.axis('off')
457457

458+
def _gen_best_concepts_crops(self,
459+
nb_crops: int = 10,
460+
nb_most_important_concepts: int = None) \
461+
-> Tuple[int, float, np.ndarray]:
462+
"""
463+
Generate the best concept crops for each concept.
464+
465+
Parameters
466+
----------
467+
nb_crops : int
468+
The number of crops (patches) to display per concept. Defaults to 10.
469+
nb_most_important_concepts : int
470+
The number of concepts to consider. If provided, only take into account
471+
nb_most_important_concepts, otherwise use them all.
472+
Default is None.
473+
Returns
474+
-------
475+
Tuple
476+
A tuple containing:
477+
- The current concept id.
478+
- The overall importance score for this concept.
479+
- An array containing the best crops for this concept.
480+
"""
481+
most_important_concepts = self.sensitivity.most_important_concepts
482+
if nb_most_important_concepts is not None:
483+
most_important_concepts = most_important_concepts[:nb_most_important_concepts]
484+
485+
for c_id in most_important_concepts:
486+
best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[::-1][:nb_crops]
487+
best_crops = np.array(self.factorization.crops)[best_crops_ids]
488+
c_id_importance = self.sensitivity.importances[c_id]
489+
yield c_id, c_id_importance, best_crops
490+
491+
492+
class CraftImageVisualizationMixin():
493+
"""
494+
Class containing image visualization methods for Craft.
495+
"""
458496
def plot_concepts_crops(self,
459497
nb_crops: int = 10,
460498
nb_most_important_concepts: int = None,
@@ -474,17 +512,11 @@ def plot_concepts_crops(self,
474512
If True, then print the importance value of each concept,
475513
otherwise no textual output will be printed.
476514
"""
477-
most_important_concepts = self.sensitivity.most_important_concepts
478-
if nb_most_important_concepts is not None:
479-
most_important_concepts = most_important_concepts[:nb_most_important_concepts]
480-
481-
for c_id in most_important_concepts:
482-
best_crops_ids = np.argsort(self.factorization.crops_u[:, c_id])[::-1][:nb_crops]
483-
best_crops = np.array(self.factorization.crops)[best_crops_ids]
484-
515+
for c_id, c_id_importance, best_crops in \
516+
self._gen_best_concepts_crops(nb_crops, nb_most_important_concepts):
485517
if verbose:
486518
print(f"Concept {c_id} has an importance value of " \
487-
f"{self.sensitivity.importances[c_id]:.2f}")
519+
f"{c_id_importance:.2f}")
488520
plt.figure(figsize=(7, (2.5/2)*ceil(nb_crops/5)))
489521
for i in range(nb_crops):
490522
plt.subplot(ceil(nb_crops/5), 5, i+1)

xplique/concepts/craft_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ def estimate_importance(self, nb_design: int = 32, verbose: bool = False):
104104
print(f'Estimating importances for class {class_of_interest} ')
105105
craft_instance.estimate_importance(nb_design=nb_design)
106106

107+
108+
class CraftManagerImageVisualizationMixin():
109+
"""
110+
Class containing image visualization methods for CraftManager.
111+
"""
107112
def plot_concepts_importances(self,
108113
class_id: int,
109114
nb_most_important_concepts: int = 5,

xplique/concepts/craft_tf.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
import tensorflow as tf
88
import numpy as np
99

10-
from .craft import BaseCraft
11-
from .craft_manager import BaseCraftManager
12-
13-
class CraftTf(BaseCraft):
10+
from .craft import BaseCraft, CraftImageVisualizationMixin
11+
from .craft_manager import BaseCraftManager, CraftManagerImageVisualizationMixin
1412

13+
class BaseCraftTf(BaseCraft):
1514
"""
16-
Class implementing the CRAFT Concept Extraction Mechanism on Tensorflow.
15+
Base class implementing the CRAFT Concept Extraction Mechanism on Tensorflow.
1716
1817
Parameters
1918
----------
@@ -137,10 +136,34 @@ def _to_np_array(self, inputs: tf.Tensor, dtype: type):
137136
"""
138137
return np.array(inputs, dtype)
139138

139+
class CraftTf(BaseCraftTf, CraftImageVisualizationMixin):
140+
"""
141+
Base class implementing the CRAFT Concept Extraction Mechanism on Tensorflow,
142+
adapted for image processing.
140143
141-
class CraftManagerTf(BaseCraftManager):
144+
Parameters
145+
----------
146+
input_to_latent_model
147+
The first part of the model taking an input and returning
148+
positive activations, g(.) in the original paper.
149+
Must be a Tensorflow model (tf.keras.engine.base_layer.Layer) accepting
150+
data of shape (n_samples, height, width, channels).
151+
latent_to_logit_model
152+
The second part of the model taking activation and returning
153+
logits, h(.) in the original paper.
154+
Must be a Tensorflow model (tf.keras.engine.base_layer.Layer).
155+
number_of_concepts
156+
The number of concepts to extract. Default is 20.
157+
batch_size
158+
The batch size to use during training and prediction. Default is 64.
159+
patch_size
160+
The size of the patches to extract from the input data. Default is 64.
142161
"""
143-
Class implementing the CraftManager on Tensorflow.
162+
163+
164+
class BaseCraftManagerTf(BaseCraftManager):
165+
"""
166+
Base class implementing the CraftManager on Tensorflow.
144167
This manager creates one CraftTf instance per class to explain.
145168
146169
Parameters
@@ -198,3 +221,33 @@ def compute_predictions(self):
198221
y_preds = np.array(tf.argmax(self.latent_to_logit_model.predict(
199222
self.input_to_latent_model.predict(self.inputs)), 1))
200223
return y_preds
224+
225+
class CraftManagerTf(BaseCraftManagerTf, CraftManagerImageVisualizationMixin):
226+
"""
227+
Class implementing the CraftManager on Tensorflow, adapted for image processing.
228+
This manager creates one CraftTf instance per class to explain.
229+
230+
Parameters
231+
----------
232+
input_to_latent_model
233+
The first part of the model taking an input and returning
234+
positive activations, g(.) in the original paper.
235+
Must return positive activations.
236+
latent_to_logit_model
237+
The second part of the model taking activation and returning
238+
logits, h(.) in the original paper.
239+
inputs
240+
Input data of shape (n_samples, height, width, channels).
241+
(x1, x2, ..., xn) in the paper.
242+
labels
243+
Labels of the inputs of shape (n_samples, class_id)
244+
list_of_class_of_interest
245+
A list of the class ids to explain. The manager will instantiate one
246+
CraftTf object per element of this list.
247+
number_of_concepts
248+
The number of concepts to extract. Default is 20.
249+
batch_size
250+
The batch size to use during training and prediction. Default is 64.
251+
patch_size
252+
The size of the patches (crops) to extract from the input data. Default is 64.
253+
"""

xplique/concepts/craft_torch.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
"""
44

55
from typing import Callable, Optional, Tuple
6+
from types import MethodType
67
from math import ceil
78
import torch
89
from torch import nn
910
import numpy as np
1011

11-
from .craft import BaseCraft
12-
from .craft_manager import BaseCraftManager
12+
from .craft import BaseCraft, CraftImageVisualizationMixin
13+
from .craft_manager import BaseCraftManager, CraftManagerImageVisualizationMixin
1314

1415
def _batch_inference(model: torch.nn.Module,
1516
dataset: torch.Tensor,
@@ -59,9 +60,9 @@ def _batch_inference(model: torch.nn.Module,
5960
return results
6061

6162

62-
class CraftTorch(BaseCraft):
63+
class BaseCraftTorch(BaseCraft):
6364
"""
64-
Class Implementing the CRAFT Concept Extraction Mechanism on Pytorch.
65+
Base class implementing the CRAFT Concept Extraction Mechanism on Pytorch.
6566
6667
Parameters
6768
----------
@@ -83,7 +84,6 @@ class CraftTorch(BaseCraft):
8384
device
8485
The device to use. Default is 'cuda'.
8586
"""
86-
8787
def __init__(self, input_to_latent_model: Callable,
8888
latent_to_logit_model: Callable,
8989
number_of_concepts: int = 20,
@@ -96,11 +96,13 @@ def __init__(self, input_to_latent_model: Callable,
9696
self.device = device
9797

9898
# Check model type
99+
is_method = isinstance(input_to_latent_model, MethodType) & \
100+
isinstance(latent_to_logit_model, MethodType)
99101
is_torch_model = issubclass(type(input_to_latent_model), torch.nn.modules.module.Module) & \
100102
issubclass(type(latent_to_logit_model), torch.nn.modules.module.Module)
101-
if not is_torch_model:
103+
if not (is_method or is_torch_model):
102104
raise TypeError('input_to_latent_model and latent_to_logit_model are not ' \
103-
'Pytorch modules')
105+
'Pytorch modules nor methods')
104106

105107
def _latent_predict(self, inputs: torch.Tensor, resize=None) -> torch.Tensor:
106108
"""
@@ -204,10 +206,35 @@ def _to_np_array(self, inputs: torch.Tensor, dtype: type = None):
204206
return res.astype(dtype)
205207
return res
206208

209+
class CraftTorch(BaseCraftTorch, CraftImageVisualizationMixin):
210+
"""
211+
Class Implementing the CRAFT Concept Extraction Mechanism on Pytorch,
212+
adapted for image processing.
207213
208-
class CraftManagerTorch(BaseCraftManager):
214+
Parameters
215+
----------
216+
input_to_latent_model
217+
The first part of the model taking an input and returning
218+
positive activations, g(.) in the original paper.
219+
Must be a Pytorch model (torch.nn.modules.module.Module) accepting
220+
data of shape (n_samples, channels, height, width).
221+
latent_to_logit_model
222+
The second part of the model taking activation and returning
223+
logits, h(.) in the original paper.
224+
Must be a Pytorch model (torch.nn.modules.module.Module).
225+
number_of_concepts
226+
The number of concepts to extract. Default is 20.
227+
batch_size
228+
The batch size to use during training and prediction. Default is 64.
229+
patch_size
230+
The size of the patches (crops) to extract from the input data. Default is 64.
231+
device
232+
The device to use. Default is 'cuda'.
233+
"""
234+
235+
class BaseCraftManagerTorch(BaseCraftManager):
209236
"""
210-
Class implementing the CraftManager on Tensorflow.
237+
Base class implementing the CraftManager on Pytorch.
211238
This manager creates one CraftTorch instance per class to explain.
212239
213240
Parameters
@@ -270,3 +297,33 @@ def compute_predictions(self):
270297
device=self.device)
271298
y_preds = np.array(torch.argmax(activations, -1)) # pylint disable=no-member
272299
return y_preds
300+
301+
class CraftManagerTorch(BaseCraftManagerTorch, CraftManagerImageVisualizationMixin):
302+
"""
303+
Class implementing the CraftManager on Pytorch, adapted for image processing.
304+
This manager creates one CraftTorch instance per class to explain.
305+
306+
Parameters
307+
----------
308+
input_to_latent_model
309+
The first part of the model taking an input and returning
310+
positive activations, g(.) in the original paper.
311+
Must return positive activations.
312+
latent_to_logit_model
313+
The second part of the model taking activation and returning
314+
logits, h(.) in the original paper.
315+
inputs
316+
Input data of shape (n_samples, height, width, channels).
317+
(x1, x2, ..., xn) in the paper.
318+
labels
319+
Labels of the inputs of shape (n_samples, class_id)
320+
list_of_class_of_interest
321+
A list of the class ids to explain. The manager will instantiate one
322+
CraftTorch object per element of this list.
323+
number_of_concepts
324+
The number of concepts to extract. Default is 20.
325+
batch_size
326+
The batch size to use during training and prediction. Default is 64.
327+
patch_size
328+
The size of the patches (crops) to extract from the input data. Default is 64.
329+
"""

0 commit comments

Comments
 (0)