Skip to content

Commit 51967a7

Browse files
authored
New CAM Method: FinerCAM (#561)
* initial * update new method finer cam * update * update * update
1 parent fd4b5c8 commit 51967a7

File tree

6 files changed

+94
-5
lines changed

6 files changed

+94
-5
lines changed

README.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
4848
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
4949
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
5050
| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.|
51+
| FinerCAM | Improves fine-grained classification by comparing similar classes, suppressing shared features and highlighting discriminative details. |
5152
## Visual Examples
5253

5354
| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
@@ -290,7 +291,7 @@ To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:
290291

291292
You can choose between:
292293

293-
`GradCAM` , `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM`, `FullGrad` and `EigenCAM`.
294+
`GradCAM`, `HiResCAM`, `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM`, `LayerCAM`, `FullGrad`, `EigenCAM`, `ShapleyCAM`, and `FinerCAM`.
294295

295296
Some methods like ScoreCAM and AblationCAM require a large number of forward passes,
296297
and have a batched implementation.
@@ -368,3 +369,8 @@ Bourqui, Jenny Benois-Pineau, Akka Zemmar`
368369
https://arxiv.org/abs/2501.06261 <br>
369370
`CAMs as Shapley Value-based Explainers
370371
Huaiguang Cai`
372+
373+
374+
https://arxiv.org/pdf/2501.11309 <br>
375+
`Finer-CAM: Spotting the Difference Reveals Finer Details for Visual Explanation`
376+
`Ziheng Zhang*, Jianyang Gu*, Arpita Chowdhury, Zheda Mai, David Carlyn, Tanya Berger-Wolf, Yu Su, Wei-Lun Chao`

cam.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from pytorch_grad_cam import (
88
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
99
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
10-
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
10+
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM,
11+
FinerCAM
1112
)
1213
from pytorch_grad_cam import GuidedBackpropReLUModel
1314
from pytorch_grad_cam.utils.image import (
@@ -37,7 +38,8 @@ def get_args():
3738
'gradcam', 'fem', 'hirescam', 'gradcam++',
3839
'scorecam', 'xgradcam', 'ablationcam',
3940
'eigencam', 'eigengradcam', 'layercam',
40-
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
41+
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam',
42+
'finercam'
4143
],
4244
help='CAM method')
4345

@@ -76,7 +78,8 @@ def get_args():
7678
"fem": FEM,
7779
"gradcamelementwise": GradCAMElementWise,
7880
'kpcacam': KPCA_CAM,
79-
'shapleycam': ShapleyCAM
81+
'shapleycam': ShapleyCAM,
82+
'finercam': FinerCAM
8083
}
8184

8285
if args.device=='hpu':
@@ -147,4 +150,4 @@ def get_args():
147150

148151
cv2.imwrite(cam_output_path, cam_image)
149152
cv2.imwrite(gb_output_path, gb)
150-
cv2.imwrite(cam_gb_output_path, cam_gb)
153+
cv2.imwrite(cam_gb_output_path, cam_gb)

image.png

145 KB
Loading

pytorch_grad_cam/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pytorch_grad_cam.grad_cam import GradCAM
2+
from pytorch_grad_cam.finer_cam import FinerCAM
23
from pytorch_grad_cam.shapley_cam import ShapleyCAM
34
from pytorch_grad_cam.fem import FEM
45
from pytorch_grad_cam.hirescam import HiResCAM

pytorch_grad_cam/finer_cam.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
import torch
3+
from typing import List, Callable
4+
from pytorch_grad_cam.base_cam import BaseCAM
5+
from pytorch_grad_cam import GradCAM
6+
from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget
7+
8+
class FinerCAM:
    """Wrap a base CAM method and drive it with ``FinerWeightedTarget`` targets.

    The wrapper owns an instance of ``base_method`` (``GradCAM`` by default)
    and reuses its hooks, per-layer CAM computation and layer aggregation.
    When no explicit targets are given, it builds one ``FinerWeightedTarget``
    per sample that contrasts the target class with the classes whose scores
    are closest to the target score.
    """

    def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module],
                 reshape_transform: Callable = None, base_method=GradCAM):
        """Create the wrapped CAM object.

        Args:
            model: Model to explain; forwarded to ``base_method``.
            target_layers: Layers to compute the CAM on; forwarded to
                ``base_method``.
            reshape_transform: Optional activation reshape callable; forwarded
                to ``base_method``.
            base_method: CAM class instantiated as
                ``base_method(model, target_layers, reshape_transform)``.
        """
        self.base_cam = base_method(model, target_layers, reshape_transform)
        # Mirror the flags of the wrapped CAM so forward() follows its behavior.
        self.compute_input_gradient = self.base_cam.compute_input_gradient
        self.uses_gradients = self.base_cam.uses_gradients

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def forward(self, input_tensor: torch.Tensor, targets: List[Callable] = None,
                eigen_smooth: bool = False, alpha: float = 1,
                comparison_categories: List[int] = None, target_idx: int = None
                ) -> np.ndarray:
        """Compute the CAM for ``input_tensor``.

        Args:
            input_tensor: Model input; moved to the wrapped CAM's device.
            targets: Optional list of callables, one per sample, each mapping
                a sample's model output to a scalar score. When given,
                ``alpha``, ``comparison_categories`` and ``target_idx`` are
                ignored.
            eigen_smooth: Forwarded to ``compute_cam_per_layer``.
            alpha: Comparison strength passed to ``FinerWeightedTarget``.
            comparison_categories: Ranks (not class ids) into the per-sample
                ordering of classes by ``|score - target score|``. Rank 0 is
                the class closest to the target score, which is normally the
                target itself. Defaults to ``[1, 2, 3]``.
            target_idx: Class index whose score is the reference; when
                ``None`` the highest-scoring class of each sample is used.

        Returns:
            The result of the wrapped CAM's ``aggregate_multi_layers``.
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if comparison_categories is None:
            comparison_categories = [1, 2, 3]

        input_tensor = input_tensor.to(self.base_cam.device)

        if self.compute_input_gradient:
            input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

        outputs = self.base_cam.activations_and_grads(input_tensor)

        if targets is None:
            output_data = outputs.detach().cpu().numpy()
            if target_idx is None:
                target_logits = np.max(output_data, axis=-1)
            else:
                target_logits = output_data[:, target_idx]
            # Sort class indices for each sample based on the absolute
            # difference between the class scores and the target logit, in
            # ascending order. The most similar classes (smallest difference)
            # appear first.
            sorted_indices = np.argsort(
                np.abs(output_data - target_logits[:, None]), axis=-1)
            targets = [
                FinerWeightedTarget(
                    int(ranked[0]),
                    [int(ranked[rank]) for rank in comparison_categories],
                    alpha)
                for ranked in sorted_indices
            ]

        if self.uses_gradients:
            self.base_cam.model.zero_grad()
            loss = sum([target(output) for target, output in zip(targets, outputs)])
            if self.base_cam.detach:
                loss.backward(retain_graph=True)
            else:
                # keep the computational graph, create_graph = True is needed for hvp
                # loss.backward(retain_graph=True, create_graph=True) is avoided
                # because it warns about creating a reference cycle.
                torch.autograd.grad(loss, input_tensor, retain_graph=True, create_graph=True)
            if 'hpu' in str(self.base_cam.device):
                # Writing ``self.base_cam.__htcore`` here would be name-mangled
                # to ``_FinerCAM__htcore`` and always raise AttributeError.
                # NOTE(review): assumes the wrapped CAM stores the Habana core
                # module as BaseCAM's private ``__htcore`` attribute; confirm
                # against base_cam.py.
                htcore = getattr(self.base_cam, "_BaseCAM__htcore")
                htcore.mark_step()

        cam_per_layer = self.base_cam.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
        return self.base_cam.aggregate_multi_layers(cam_per_layer)

pytorch_grad_cam/utils/model_targets.py

+26
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,29 @@ def __call__(self, model_outputs):
119119
score = ious[0, index] + model_outputs["scores"][index]
120120
output = output + score
121121
return output
122+
123+
class FinerWeightedTarget:
    """
    Computes a weighted difference between a primary category and a set of comparison categories.

    This target calculates the difference between the score for the main category and each of the comparison categories.
    It obtains a weight for each comparison category from the softmax probabilities of the model output and computes a
    weighted difference scaled by a comparison strength factor alpha:

        sum_i p_i * (s_main - alpha * s_i) / (sum_i p_i + 1e-9)

    where ``s`` are the raw model outputs and ``p`` their softmax over the last dimension.
    """

    def __init__(self, main_category, comparison_categories, alpha):
        """
        Args:
            main_category: Index of the primary category along the last dimension of the model output.
            comparison_categories: Indices of the categories the primary one is contrasted with.
            alpha: Comparison strength multiplying each comparison score.
        """
        self.main_category = main_category
        self.comparison_categories = comparison_categories
        self.alpha = alpha

    def __call__(self, model_output):
        """
        Args:
            model_output: Tensor of scores whose last dimension indexes categories
                (a single 1-D output or a batch of outputs).

        Returns:
            Tensor with the last dimension of ``model_output`` reduced away. When there are no
            comparison categories, the plain main-category score is returned.
        """
        # Ellipsis indexing selects along the last axis for 1-D and batched outputs alike.
        main_score = model_output[..., self.main_category]

        # Without comparisons the weighted sums would collapse to Python numbers and the
        # result would not be a differentiable tensor; fall back to the main score.
        if len(self.comparison_categories) == 0:
            return main_score

        prob = torch.softmax(model_output, dim=-1)
        weights = [prob[..., idx] for idx in self.comparison_categories]
        numerator = sum(
            w * (main_score - self.alpha * model_output[..., idx])
            for w, idx in zip(weights, self.comparison_categories)
        )
        denominator = sum(weights)

        # The small epsilon guards against division by zero when all weights underflow.
        return numerator / (denominator + 1e-9)

0 commit comments

Comments
 (0)