Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New CAM Method: FinerCAM #561

Merged
merged 5 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.|
| FinerCAM | Improves fine-grained classification by comparing similar classes, suppressing shared features and highlighting discriminative details. |
## Visual Examples

| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
Expand Down Expand Up @@ -290,7 +291,7 @@ To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:

You can choose between:

`GradCAM` , `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM`, `FullGrad` and `EigenCAM`.
`GradCAM` , `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM`, `FullGrad`, `EigenCAM`, `ShapleyCAM`, and `FinerCAM`.

Some methods like ScoreCAM and AblationCAM require a large number of forward passes,
and have a batched implementation.
Expand Down Expand Up @@ -368,3 +369,8 @@ Bourqui, Jenny Benois-Pineau, Akka Zemmar`
https://arxiv.org/abs/2501.06261 <br>
`CAMs as Shapley Value-based Explainers
Huaiguang Cai`


https://arxiv.org/pdf/2501.11309 <br>
`Finer-CAM: Spotting the Difference Reveals Finer Details for Visual Explanation`
`Ziheng Zhang*, Jianyang Gu*, Arpita Chowdhury, Zheda Mai, David Carlyn, Tanya Berger-Wolf, Yu Su, Wei-Lun Chao`
11 changes: 7 additions & 4 deletions cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from pytorch_grad_cam import (
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM,
FinerCAM
)
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
Expand Down Expand Up @@ -37,7 +38,8 @@ def get_args():
'gradcam', 'fem', 'hirescam', 'gradcam++',
'scorecam', 'xgradcam', 'ablationcam',
'eigencam', 'eigengradcam', 'layercam',
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam',
'finercam'
],
help='CAM method')

Expand Down Expand Up @@ -76,7 +78,8 @@ def get_args():
"fem": FEM,
"gradcamelementwise": GradCAMElementWise,
'kpcacam': KPCA_CAM,
'shapleycam': ShapleyCAM
'shapleycam': ShapleyCAM,
'finercam': FinerCAM
}

if args.device=='hpu':
Expand Down Expand Up @@ -147,4 +150,4 @@ def get_args():

cv2.imwrite(cam_output_path, cam_image)
cv2.imwrite(gb_output_path, gb)
cv2.imwrite(cam_gb_output_path, cam_gb)
cv2.imwrite(cam_gb_output_path, cam_gb)
Binary file added image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions pytorch_grad_cam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.finer_cam import FinerCAM
from pytorch_grad_cam.shapley_cam import ShapleyCAM
from pytorch_grad_cam.fem import FEM
from pytorch_grad_cam.hirescam import HiResCAM
Expand Down
53 changes: 53 additions & 0 deletions pytorch_grad_cam/finer_cam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import torch
from typing import List, Callable
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget

class FinerCAM:
    """Wrapper that drives a base CAM method with ``FinerWeightedTarget`` targets.

    The wrapped CAM object (``base_method(model, target_layers,
    reshape_transform)``) supplies the activation/gradient hooks and the
    per-layer CAM computation; this class only decides which targets the
    loss is built from and runs the backward pass.
    """

    def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module],
                 reshape_transform: Callable = None, base_method=GradCAM):
        """Build the underlying CAM object and mirror its gradient flags.

        Args:
            model: Model to explain.
            target_layers: Layers whose activations the base CAM uses.
            reshape_transform: Optional callable forwarded to the base CAM.
            base_method: CAM class used to build the wrapped object.
        """
        self.base_cam = base_method(model, target_layers, reshape_transform)
        self.compute_input_gradient = self.base_cam.compute_input_gradient
        self.uses_gradients = self.base_cam.uses_gradients

    def __call__(self, *args, **kwargs):
        """Forward all arguments to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def forward(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module] = None,
                eigen_smooth: bool = False, alpha: float = 1,
                comparison_categories: List[int] = None, target_idx: int = None
                ) -> np.ndarray:
        """Compute the CAM for ``input_tensor``.

        Args:
            input_tensor: Model input; moved to the base CAM's device.
            targets: Optional per-sample target callables. When ``None``,
                one ``FinerWeightedTarget`` per sample is built from the
                model output.
            eigen_smooth: Forwarded to ``compute_cam_per_layer``.
            alpha: Comparison strength passed to ``FinerWeightedTarget``.
            comparison_categories: Ranks (positions in the per-sample
                similarity ordering, rank 0 being the reference class) of
                the classes to compare against. Defaults to ``[1, 2, 3]``.
            target_idx: Class index whose logit is the reference. When
                ``None`` the maximum logit of each sample is used.

        Returns:
            The value of ``aggregate_multi_layers`` of the wrapped CAM.
        """
        # Resolve the default here rather than in the signature so no
        # mutable list object is shared between calls.
        if comparison_categories is None:
            comparison_categories = [1, 2, 3]

        input_tensor = input_tensor.to(self.base_cam.device)

        if self.compute_input_gradient:
            input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

        outputs = self.base_cam.activations_and_grads(input_tensor)

        if targets is None:
            output_data = outputs.detach().cpu().numpy()
            if target_idx is None:
                target_logits = np.max(output_data, axis=-1)
            else:
                target_logits = output_data[:, target_idx]
            # Sort class indices for each sample based on the absolute difference
            # between the class scores and the target logit, in ascending order.
            # The most similar classes (smallest difference) appear first.
            sorted_indices = np.argsort(
                np.abs(output_data - target_logits[:, None]), axis=-1)
            targets = [
                FinerWeightedTarget(
                    int(ranking[0]),
                    [int(ranking[rank]) for rank in comparison_categories],
                    alpha)
                for ranking in sorted_indices
            ]

        if self.uses_gradients:
            self.base_cam.model.zero_grad()
            loss = sum(target(output) for target, output in zip(targets, outputs))
            if self.base_cam.detach:
                loss.backward(retain_graph=True)
            else:
                # keep the computational graph, create_graph = True is needed for hvp.
                # autograd.grad is used instead of
                # loss.backward(retain_graph=True, create_graph=True), which raises
                # "UserWarning: Using backward() with create_graph=True will create
                # a reference cycle".
                torch.autograd.grad(loss, input_tensor, retain_graph=True, create_graph=True)
            if 'hpu' in str(self.base_cam.device):
                # Inside this class body `self.base_cam.__htcore` is name-mangled
                # to `_FinerCAM__htcore`, which the wrapped object never defines.
                # NOTE(review): assumes the wrapped CAM stores the habana core
                # module as the private `__htcore` attribute of `BaseCAM`
                # (i.e. `_BaseCAM__htcore`) - confirm against base_cam.py.
                htcore = getattr(self.base_cam, '_BaseCAM__htcore')
                htcore.mark_step()

        cam_per_layer = self.base_cam.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
        return self.base_cam.aggregate_multi_layers(cam_per_layer)
26 changes: 26 additions & 0 deletions pytorch_grad_cam/utils/model_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,29 @@ def __call__(self, model_outputs):
score = ious[0, index] + model_outputs["scores"][index]
output = output + score
return output

class FinerWeightedTarget:
    """
    Target that contrasts one main category against several comparison categories.

    For model output ``s`` and softmax probabilities ``p`` (taken over the last
    axis), the returned value is

        sum_i p[i] * (s[main] - alpha * s[i]) / (sum_i p[i] + 1e-9)

    where ``i`` runs over the comparison categories. ``alpha`` controls how
    strongly the comparison scores are subtracted from the main score.
    """

    def __init__(self, main_category, comparison_categories, alpha):
        self.main_category = main_category
        self.comparison_categories = comparison_categories
        self.alpha = alpha

    def __call__(self, model_output):
        # Ellipsis indexing picks the category along the last axis; for a
        # 1-D output it is the same as plain indexing.
        main_score = model_output[..., self.main_category]
        probabilities = torch.softmax(model_output, dim=-1)

        weighted_margin_total = 0
        weight_total = 0
        for category in self.comparison_categories:
            weight = probabilities[..., category]
            margin = main_score - self.alpha * model_output[..., category]
            weighted_margin_total = weighted_margin_total + weight * margin
            weight_total = weight_total + weight

        # The small epsilon guards against a zero total weight.
        return weighted_margin_total / (weight_total + 1e-9)