Skip to content

Commit 48a3ae1

Browse files
New CAM technique: KPCA-CAM (#534)
* Update svd_on_activations.py added kernel pca * Create kpca_cam.py * Update svd_on_activations.py
1 parent 18144f2 commit 48a3ae1

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

pytorch_grad_cam/kpca_cam.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from pytorch_grad_cam.base_cam import BaseCAM
2+
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection_kernel
3+
4+
class KPCA_CAM(BaseCAM):
5+
def __init__(self, model, target_layers,
6+
reshape_transform=None, kernel='sigmoid', gamma=None):
7+
super(KPCA_CAM, self).__init__(model,
8+
target_layers,
9+
reshape_transform,
10+
uses_gradients=False,
11+
kernel=kernel, gamma=gamma)
12+
13+
def get_cam_image(self,
14+
input_tensor,
15+
target_layer,
16+
target_category,
17+
activations,
18+
grads,
19+
eigen_smooth):
20+
return get_2d_projection_kernel(activations, self.kernel, self.gamma)

pytorch_grad_cam/utils/svd_on_activations.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from sklearn.decomposition import KernelPCA
23

34

45
def get_2d_projection(activation_batch):
@@ -17,3 +18,18 @@ def get_2d_projection(activation_batch):
1718
projection = projection.reshape(activations.shape[1:])
1819
projections.append(projection)
1920
return np.float32(projections)
21+
22+
23+
24+
def get_2d_projection_kernel(activation_batch, kernel='sigmoid', gamma=None):
25+
activation_batch[np.isnan(activation_batch)] = 0
26+
projections = []
27+
for activations in activation_batch:
28+
reshaped_activations = activations.reshape(activations.shape[0], -1).transpose()
29+
reshaped_activations = reshaped_activations - reshaped_activations.mean(axis=0)
30+
# Apply Kernel PCA
31+
kpca = KernelPCA(n_components=1, kernel=kernel, gamma=gamma)
32+
projection = kpca.fit_transform(reshaped_activations)
33+
projection = projection.reshape(activations.shape[1:])
34+
projections.append(projection)
35+
return np.float32(projections)

0 commit comments

Comments
 (0)