New CAM Method: ShapleyCAM #550

Merged · 12 commits · Jan 19, 2025
7 changes: 6 additions & 1 deletion README.md
@@ -47,6 +47,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
| Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations |
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product. |
## Visual Examples

| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
@@ -362,4 +363,8 @@ Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Shey
https://hal.science/hal-02963298/document <br>
`Features Understanding in 3D CNNs for Actions Recognition in Video
Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain
Bourqui, Jenny Benois-Pineau, Akka Zemmar`

https://arxiv.org/abs/2501.06261 <br>
`CAMs as Shapley Value-based Explainers
Huaiguang Cai`
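
A minimal usage sketch of the new method together with the new ReST target (not part of this diff; the ResNet-50, the `layer4[-1]` target layer, and ImageNet class 281 are illustrative choices, swap in your own model and preprocessed input):

```python
import torch
from torchvision.models import resnet50, ResNet50_Weights
from pytorch_grad_cam import ShapleyCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputReST

# Illustrative model, target layer, and input; any preprocessed image tensor works the same way.
model = resnet50(weights=ResNet50_Weights.DEFAULT).eval()
target_layers = [model.layer4[-1]]
input_tensor = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch

with ShapleyCAM(model=model, target_layers=target_layers) as cam:
    # ClassifierOutputReST(281): explain ImageNet class 281 ('tabby cat')
    grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputReST(281)])

print(grayscale_cam.shape)  # (1, 224, 224)
```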
11 changes: 6 additions & 5 deletions cam.py
@@ -7,13 +7,13 @@
from pytorch_grad_cam import (
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
)
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
show_cam_on_image, deprocess_image, preprocess_image
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST


def get_args():
@@ -37,7 +37,7 @@ def get_args():
'gradcam', 'fem', 'hirescam', 'gradcam++',
'scorecam', 'xgradcam', 'ablationcam',
'eigencam', 'eigengradcam', 'layercam',
'fullgrad', 'gradcamelementwise', 'kpcacam'
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
],
help='CAM method')

@@ -75,7 +75,8 @@ def get_args():
"fullgrad": FullGrad,
"fem": FEM,
"gradcamelementwise": GradCAMElementWise,
'kpcacam': KPCA_CAM
'kpcacam': KPCA_CAM,
'shapleycam': ShapleyCAM
}

if args.device=='hpu':
@@ -109,7 +110,7 @@ def get_args():
# If targets is None, the highest scoring category (for every member in the batch) will be used.
# You can target specific categories by
# targets = [ClassifierOutputTarget(281)]
# targets = [ClassifierOutputTarget(281)]
# targets = [ClassifierOutputReST(281)]
targets = None

# Using the with statement ensures the context is freed, and you can
1 change: 1 addition & 0 deletions pytorch_grad_cam/__init__.py
@@ -1,4 +1,5 @@
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.shapley_cam import ShapleyCAM
from pytorch_grad_cam.fem import FEM
from pytorch_grad_cam.hirescam import HiResCAM
from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
22 changes: 14 additions & 8 deletions pytorch_grad_cam/activations_and_gradients.py
@@ -2,11 +2,12 @@ class ActivationsAndGradients:
""" Class for extracting activations and
registering gradients from targeted intermediate layers """

def __init__(self, model, target_layers, reshape_transform):
def __init__(self, model, target_layers, reshape_transform, detach=True):
self.model = model
self.gradients = []
self.activations = []
self.reshape_transform = reshape_transform
self.detach = detach
self.handles = []
for target_layer in target_layers:
self.handles.append(
@@ -18,10 +19,12 @@ def __init__(self, model, target_layers, reshape_transform):

def save_activation(self, module, input, output):
activation = output

if self.reshape_transform is not None:
activation = self.reshape_transform(activation)
self.activations.append(activation.cpu().detach())
if self.detach:
if self.reshape_transform is not None:
activation = self.reshape_transform(activation)
self.activations.append(activation.cpu().detach())
else:
self.activations.append(activation)

def save_gradient(self, module, input, output):
if not hasattr(output, "requires_grad") or not output.requires_grad:
@@ -30,9 +33,12 @@ def save_gradient(self, module, input, output):

# Gradients are computed in reverse order
def _store_grad(grad):
if self.reshape_transform is not None:
grad = self.reshape_transform(grad)
self.gradients = [grad.cpu().detach()] + self.gradients
if self.detach:
if self.reshape_transform is not None:
grad = self.reshape_transform(grad)
self.gradients = [grad.cpu().detach()] + self.gradients
else:
self.gradients = [grad] + self.gradients

output.register_hook(_store_grad)

22 changes: 18 additions & 4 deletions pytorch_grad_cam/base_cam.py
@@ -19,6 +19,7 @@ def __init__(
compute_input_gradient: bool = False,
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None,
detach: bool = True,
) -> None:
self.model = model.eval()
self.target_layers = target_layers
@@ -45,7 +46,8 @@ def __init__(
else:
self.tta_transforms = tta_transforms

self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
self.detach = detach
self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)

""" Get a vector of weights for every channel in the target layer.
Methods that return weights channels,
@@ -71,6 +73,8 @@ def get_cam_image(
eigen_smooth: bool = False,
) -> np.ndarray:
weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
if isinstance(activations, torch.Tensor):
activations = activations.cpu().detach().numpy()
# 2D conv
if len(activations.shape) == 4:
weighted_activations = weights[:, :, None, None] * activations
@@ -103,7 +107,13 @@ def forward(
if self.uses_gradients:
self.model.zero_grad()
loss = sum([target(output) for target, output in zip(targets, outputs)])
loss.backward(retain_graph=True)
if self.detach:
loss.backward(retain_graph=True)
else:
# keep the computational graph; create_graph=True is needed for the Hessian-vector product (hvp)
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
# The equivalent call below raises "UserWarning: Using backward() with create_graph=True will create a reference cycle", so torch.autograd.grad is used instead
# loss.backward(retain_graph=True, create_graph=True)
if 'hpu' in str(self.device):
self.__htcore.mark_step()

Expand Down Expand Up @@ -132,8 +142,12 @@ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]
def compute_cam_per_layer(
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
) -> np.ndarray:
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
if self.detach:
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
else:
activations_list = [a for a in self.activations_and_grads.activations]
grads_list = [g for g in self.activations_and_grads.gradients]
target_size = self.get_target_width_height(input_tensor)

cam_per_target_layer = []
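
A side-by-side sketch of the two gradient paths the `detach` switch selects between (toy stand-ins, not the library's objects):

```python
import torch

model = torch.nn.Linear(8, 3)
input_tensor = torch.randn(2, 8, requires_grad=True)
loss = model(input_tensor)[:, 0].sum()

# detach=True path: ordinary backward; gradients land in .grad and are later detached by the hooks.
loss.backward(retain_graph=True)

# detach=False path (ShapleyCAM): create_graph=True keeps the gradients themselves differentiable,
# which is what later allows a Hessian-vector product against the activations.
(grad_in,) = torch.autograd.grad(loss, input_tensor, retain_graph=True, create_graph=True)
print(grad_in.requires_grad)  # True, so second-order derivatives are possible
```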
60 changes: 60 additions & 0 deletions pytorch_grad_cam/shapley_cam.py
@@ -0,0 +1,60 @@
from typing import Callable, List, Optional, Tuple
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch
import numpy as np

"""
Weights the activation maps using the gradient and Hessian-Vector product.
This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods (including GradCAM, HiResCAM, and the original CAM) from a Shapley value perspective.
"""
class ShapleyCAM(BaseCAM):
def __init__(self, model, target_layers,
reshape_transform=None):
super(
ShapleyCAM,
self).__init__(
model = model,
target_layers = target_layers,
reshape_transform = reshape_transform,
compute_input_gradient = True,
uses_gradients = True,
detach = False)

def get_cam_weights(self,
input_tensor,
target_layer,
target_category,
activations,
grads):

hvp = torch.autograd.grad(
outputs=grads,
inputs=activations,
grad_outputs=activations,
retain_graph=False,
allow_unused=True
)[0]
# print(torch.max(hvp[0]).item()) # check if hvp is not all zeros
if hvp is None:
hvp = torch.tensor(0).to(self.device)
else:
if self.activations_and_grads.reshape_transform is not None:
hvp = self.activations_and_grads.reshape_transform(hvp)

if self.activations_and_grads.reshape_transform is not None:
activations = self.activations_and_grads.reshape_transform(activations)
grads = self.activations_and_grads.reshape_transform(grads)

weight = (grads - 0.5 * hvp).detach().cpu().numpy()
# 2D image
if len(activations.shape) == 4:
weight = np.mean(weight, axis=(2, 3))
return weight
# 3D image
elif len(activations.shape) == 5:
weight = np.mean(weight, axis=(2, 3, 4))
return weight
else:
raise ValueError("Invalid grads shape."
"Shape of grads should be 4 (2D image) or 5 (3D image).")
16 changes: 16 additions & 0 deletions pytorch_grad_cam/utils/model_targets.py
@@ -23,6 +23,22 @@ def __call__(self, model_output):
return torch.softmax(model_output, dim=-1)[:, self.category]


class ClassifierOutputReST:
"""
Using both pre-softmax and post-softmax, proposed in https://arxiv.org/abs/2501.06261
"""
def __init__(self, category):
self.category = category
def __call__(self, model_output):
if len(model_output.shape) == 1:
target = torch.tensor([self.category], device=model_output.device)
model_output = model_output.unsqueeze(0)
return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
else:
target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
return model_output[:, self.category] - torch.nn.functional.cross_entropy(model_output, target)


class BinaryClassifierOutputTarget:
def __init__(self, category):
self.category = category
Expand Down