
Commit fd4b5c8

New CAM Method: ShapleyCAM (#550)
* ShapleyCAM: weight the activation maps using the gradient and the Hessian-vector product.
* name
* ReST example
* comments
* Update README.md
* Update README.md
* Update README.md
* update a simpler version
* comments
* forward function in shapley_cam.py still needed. This is because the calculation of the Hessian-vector product (HVP) requires the computation graph to be retained; see the comments in line 37 or 38.
* delete forward function in shapley_cam.py
* comments
1 parent b1cab2d commit fd4b5c8

7 files changed: +121 −18 lines


README.md (+6 −1)

```diff
@@ -47,6 +47,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
 | Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations |
 | KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
 | FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
+| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product. |
 ## Visual Examples
 
 | What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
@@ -362,4 +363,8 @@ Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Shey
 https://hal.science/hal-02963298/document <br>
 `Features Understanding in 3D CNNs for Actions Recognition in Video
 Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain
-Bourqui, Jenny Benois-Pineau, Akka Zemmar`
+Bourqui, Jenny Benois-Pineau, Akka Zemmar`
+
+https://arxiv.org/abs/2501.06261 <br>
+`CAMs as Shapley Value-based Explainers
+Huaiguang Cai`
```
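For context, a minimal usage sketch of the new method, assuming it follows the same interface as the library's other CAM classes; the ResNet-50 model, the target layer, the ImageNet class index 281, and the random input tensor are placeholders, not part of this commit:

```python
import torch
from torchvision.models import resnet50, ResNet50_Weights
from pytorch_grad_cam import ShapleyCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputReST

# Placeholder model and target layer; any torchvision classifier works the same way.
model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()
target_layers = [model.layer4[-1]]

# Placeholder input; in practice this would be a preprocessed image batch.
input_tensor = torch.randn(1, 3, 224, 224)
targets = [ClassifierOutputReST(281)]  # 281 = 'tabby, tabby cat' in ImageNet

# The `with` form frees the hooks when done, as in the library's other examples.
with ShapleyCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]  # (H, W) heatmap
```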

cam.py (+6 −5)

```diff
@@ -7,13 +7,13 @@
 from pytorch_grad_cam import (
     GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
     AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
-    LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM
+    LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
 )
 from pytorch_grad_cam import GuidedBackpropReLUModel
 from pytorch_grad_cam.utils.image import (
     show_cam_on_image, deprocess_image, preprocess_image
 )
-from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST
 
 
 def get_args():
@@ -37,7 +37,7 @@ def get_args():
             'gradcam', 'fem', 'hirescam', 'gradcam++',
             'scorecam', 'xgradcam', 'ablationcam',
             'eigencam', 'eigengradcam', 'layercam',
-            'fullgrad', 'gradcamelementwise', 'kpcacam'
+            'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
         ],
         help='CAM method')
 
@@ -75,7 +75,8 @@ def get_args():
         "fullgrad": FullGrad,
         "fem": FEM,
         "gradcamelementwise": GradCAMElementWise,
-        'kpcacam': KPCA_CAM
+        'kpcacam': KPCA_CAM,
+        'shapleycam': ShapleyCAM
     }
 
     if args.device=='hpu':
@@ -109,7 +110,7 @@ def get_args():
     # If targets is None, the highest scoring category (for every member in the batch) will be used.
     # You can target specific categories by
     # targets = [ClassifierOutputTarget(281)]
-    # targets = [ClassifierOutputTarget(281)]
+    # targets = [ClassifierOutputReST(281)]
     targets = None
 
     # Using the with statement ensures the context is freed, and you can
```
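The updated comment in cam.py points at the new ClassifierOutputReST target. A short sketch of how the two target styles would be passed to a CAM object (the class index 281 is a placeholder):

```python
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST

# targets=None lets the CAM pick the highest-scoring class for each image.
targets = None

# Or target a specific class explicitly; both callables map model outputs to a scalar per sample.
targets = [ClassifierOutputTarget(281)]   # the plain class-score target used elsewhere in the repo
targets = [ClassifierOutputReST(281)]     # the ReST target introduced in this commit

# grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
```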

pytorch_grad_cam/__init__.py (+1)

```diff
@@ -1,4 +1,5 @@
 from pytorch_grad_cam.grad_cam import GradCAM
+from pytorch_grad_cam.shapley_cam import ShapleyCAM
 from pytorch_grad_cam.fem import FEM
 from pytorch_grad_cam.hirescam import HiResCAM
 from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
```

pytorch_grad_cam/activations_and_gradients.py (+14 −8)

```diff
@@ -2,11 +2,12 @@ class ActivationsAndGradients:
     """ Class for extracting activations and
     registering gradients from targetted intermediate layers """
 
-    def __init__(self, model, target_layers, reshape_transform):
+    def __init__(self, model, target_layers, reshape_transform, detach=True):
         self.model = model
         self.gradients = []
         self.activations = []
         self.reshape_transform = reshape_transform
+        self.detach = detach
         self.handles = []
         for target_layer in target_layers:
             self.handles.append(
@@ -18,10 +19,12 @@ def __init__(self, model, target_layers, reshape_transform):
 
     def save_activation(self, module, input, output):
         activation = output
-
-        if self.reshape_transform is not None:
-            activation = self.reshape_transform(activation)
-        self.activations.append(activation.cpu().detach())
+        if self.detach:
+            if self.reshape_transform is not None:
+                activation = self.reshape_transform(activation)
+            self.activations.append(activation.cpu().detach())
+        else:
+            self.activations.append(activation)
 
     def save_gradient(self, module, input, output):
         if not hasattr(output, "requires_grad") or not output.requires_grad:
@@ -30,9 +33,12 @@ def save_gradient(self, module, input, output):
 
         # Gradients are computed in reverse order
         def _store_grad(grad):
-            if self.reshape_transform is not None:
-                grad = self.reshape_transform(grad)
-            self.gradients = [grad.cpu().detach()] + self.gradients
+            if self.detach:
+                if self.reshape_transform is not None:
+                    grad = self.reshape_transform(grad)
+                self.gradients = [grad.cpu().detach()] + self.gradients
+            else:
+                self.gradients = [grad] + self.gradients
 
         output.register_hook(_store_grad)
 
```
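Why the new detach flag matters: a detached, CPU-copied activation is cut off from the autograd graph, so nothing can be differentiated through it later, while detach=False keeps the live tensors the hooks saw. A toy sketch of the difference (not library code; the model and shapes are made up):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
stored = {}

def save_activation(module, inp, out):
    stored["detached"] = out.cpu().detach()  # what detach=True keeps
    stored["live"] = out                     # what detach=False keeps

handle = model[1].register_forward_hook(save_activation)
y = model(torch.randn(2, 4)).sum()

print(stored["detached"].requires_grad)  # False: the graph is gone
print(stored["live"].requires_grad)      # True: still part of the graph
# Only the live tensor can serve as an input to torch.autograd.grad afterwards:
(grad_wrt_act,) = torch.autograd.grad(y, stored["live"], create_graph=True)
handle.remove()
```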

pytorch_grad_cam/base_cam.py (+18 −4)

```diff
@@ -19,6 +19,7 @@ def __init__(
         compute_input_gradient: bool = False,
         uses_gradients: bool = True,
         tta_transforms: Optional[tta.Compose] = None,
+        detach: bool = True,
     ) -> None:
         self.model = model.eval()
         self.target_layers = target_layers
@@ -45,7 +46,8 @@ def __init__(
         else:
             self.tta_transforms = tta_transforms
 
-        self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
+        self.detach = detach
+        self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)
 
     """ Get a vector of weights for every channel in the target layer.
         Methods that return weights channels,
@@ -71,6 +73,8 @@ def get_cam_image(
         eigen_smooth: bool = False,
     ) -> np.ndarray:
         weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
+        if isinstance(activations, torch.Tensor):
+            activations = activations.cpu().detach().numpy()
         # 2D conv
         if len(activations.shape) == 4:
             weighted_activations = weights[:, :, None, None] * activations
@@ -103,7 +107,13 @@ def forward(
         if self.uses_gradients:
             self.model.zero_grad()
             loss = sum([target(output) for target, output in zip(targets, outputs)])
-            loss.backward(retain_graph=True)
+            if self.detach:
+                loss.backward(retain_graph=True)
+            else:
+                # keep the computational graph, create_graph = True is needed for hvp
+                torch.autograd.grad(loss, input_tensor, retain_graph=True, create_graph=True)
+                # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
+                # loss.backward(retain_graph=True, create_graph=True)
         if 'hpu' in str(self.device):
             self.__htcore.mark_step()
 
@@ -132,8 +142,12 @@ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]
     def compute_cam_per_layer(
         self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
     ) -> np.ndarray:
-        activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
-        grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
+        if self.detach:
+            activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
+            grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
+        else:
+            activations_list = [a for a in self.activations_and_grads.activations]
+            grads_list = [g for g in self.activations_and_grads.gradients]
         target_size = self.get_target_width_height(input_tensor)
 
         cam_per_target_layer = []
```
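The reason forward() switches to torch.autograd.grad with create_graph=True when detach=False: the first-order gradients must themselves carry a graph so that ShapleyCAM can differentiate them a second time for the Hessian-vector product. A self-contained sketch of that pattern (the cubic function below is just an example, not the CAM loss):

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x ** 3).sum()

# create_graph=True records how the gradient itself was computed ...
(grad,) = torch.autograd.grad(loss, x, create_graph=True)

# ... so a second call can differentiate through it: a Hessian-vector product.
v = torch.ones_like(x)
(hvp,) = torch.autograd.grad(grad, x, grad_outputs=v)

print(torch.allclose(hvp, 6 * x * v))  # Hessian of sum(x^3) is diag(6x), so H @ v = 6x * v
```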

pytorch_grad_cam/shapley_cam.py (new file, +60)

```python
from typing import Callable, List, Optional, Tuple
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch
import numpy as np

"""
Weights the activation maps using the gradient and the Hessian-vector product.
This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods (including GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
"""
class ShapleyCAM(BaseCAM):
    def __init__(self, model, target_layers,
                 reshape_transform=None):
        super(
            ShapleyCAM,
            self).__init__(
            model=model,
            target_layers=target_layers,
            reshape_transform=reshape_transform,
            compute_input_gradient=True,
            uses_gradients=True,
            detach=False)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):

        hvp = torch.autograd.grad(
            outputs=grads,
            inputs=activations,
            grad_outputs=activations,
            retain_graph=False,
            allow_unused=True
        )[0]
        # print(torch.max(hvp[0]).item())  # check that hvp is not all zeros
        if hvp is None:
            hvp = torch.tensor(0).to(self.device)
        else:
            if self.activations_and_grads.reshape_transform is not None:
                hvp = self.activations_and_grads.reshape_transform(hvp)

        if self.activations_and_grads.reshape_transform is not None:
            activations = self.activations_and_grads.reshape_transform(activations)
            grads = self.activations_and_grads.reshape_transform(grads)

        weight = (grads - 0.5 * hvp).detach().cpu().numpy()
        # 2D image
        if len(activations.shape) == 4:
            weight = np.mean(weight, axis=(2, 3))
            return weight
        # 3D image
        elif len(activations.shape) == 5:
            weight = np.mean(weight, axis=(2, 3, 4))
            return weight
        else:
            raise ValueError("Invalid grads shape. "
                             "Shape of grads should be 4 (2D image) or 5 (3D image).")
```

pytorch_grad_cam/utils/model_targets.py (+16)

```diff
@@ -23,6 +23,22 @@ def __call__(self, model_output):
         return torch.softmax(model_output, dim=-1)[:, self.category]
 
 
+class ClassifierOutputReST:
+    """
+    Using both pre-softmax and post-softmax, proposed in https://arxiv.org/abs/2501.06261
+    """
+    def __init__(self, category):
+        self.category = category
+    def __call__(self, model_output):
+        if len(model_output.shape) == 1:
+            target = torch.tensor([self.category], device=model_output.device)
+            model_output = model_output.unsqueeze(0)
+            return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
+        else:
+            target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
+            return model_output[:, self.category] - torch.nn.functional.cross_entropy(model_output, target)
+
+
 class BinaryClassifierOutputTarget:
     def __init__(self, category):
         self.category = category
```
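A tiny numeric check, with placeholder logits, of what the new target returns: the raw score of the chosen class minus the cross-entropy against that class, i.e. the logit plus its log-softmax probability, so both the pre-softmax and post-softmax outputs contribute:

```python
import torch
from pytorch_grad_cam.utils.model_targets import ClassifierOutputReST

logits = torch.tensor([[2.0, 0.5, -1.0]])   # placeholder (batch=1, 3 classes) output
rest = ClassifierOutputReST(0)

manual = logits[0, 0] - torch.nn.functional.cross_entropy(logits, torch.tensor([0]))
print(torch.allclose(rest(logits), manual))  # True
```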
