Skip to content

Commit 3018ba3

Browse files
authored
add clip example (#527)
* add clip example * add clip example
1 parent 48a3ae1 commit 3018ba3

File tree

2 files changed

+170
-1
lines changed

2 files changed

+170
-1
lines changed

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ ttach
66
tqdm
77
opencv-python
88
matplotlib
9-
scikit-learn
9+
scikit-learn
10+
transformers

usage_examples/clip_example

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import argparse
2+
3+
import cv2
4+
import numpy as np
5+
import torch
6+
from torch import nn
7+
from transformers import CLIPProcessor, CLIPModel
8+
9+
10+
from pytorch_grad_cam import GradCAM, \
11+
ScoreCAM, \
12+
GradCAMPlusPlus, \
13+
AblationCAM, \
14+
XGradCAM, \
15+
EigenCAM, \
16+
EigenGradCAM, \
17+
LayerCAM, \
18+
FullGrad
19+
20+
from pytorch_grad_cam.utils.image import show_cam_on_image, \
21+
preprocess_image
22+
from pytorch_grad_cam.ablation_layer import AblationLayerVit
23+
24+
25+
def get_args():
    """Parse the command-line options for the CLIP CAM example.

    Reads the arguments from ``sys.argv`` and prints whether the GPU or the
    CPU will be used.

    Returns:
        argparse.Namespace: with the attributes ``use_cuda``, ``image_path``,
        ``labels``, ``aug_smooth``, ``eigen_smooth`` and ``method``.
        ``use_cuda`` is only True when ``--use-cuda`` was passed *and*
        ``torch.cuda.is_available()`` reports a usable device.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument(
        '--image-path',
        type=str,
        default='./examples/both.png',
        help='Input image path')
    parser.add_argument(
        '--labels',
        type=str,
        nargs='+',
        default=["a cat", "a dog", "a car", "a person", "a shoe"],
        help='need recognition labels'
    )

    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen_smooth',
        action='store_true',
        # The two literals are concatenated, so the first one must end with
        # a space to keep the words apart in the rendered help.
        help='Reduce noise by taking the first principal component '
        'of cam_weights*activations')

    parser.add_argument(
        '--method',
        type=str,
        default='gradcam',
        help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam/'
        'eigencam/eigengradcam/layercam/fullgrad')

    args = parser.parse_args()
    # Silently fall back to the CPU when CUDA was requested but is missing.
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args
64+
65+
66+
def reshape_transform(tensor, height=16, width=16):
    """Rearrange token activations into a channels-first spatial map.

    The entry at index 0 along dimension 1 is discarded (presumably the
    class token - confirm against the model) and the remaining entries are
    laid out on a ``height`` x ``width`` grid.

    Args:
        tensor: activations indexed as (batch, tokens, channels).
        height: number of grid rows.
        width: number of grid columns.

    Returns:
        A tensor of shape (batch, channels, height, width).
    """
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(
        tensor.size(0), height, width, tensor.size(2))

    # (batch, height, width, channels) -> (batch, channels, height, width),
    # the layout CNN feature maps use.
    return grid.permute(0, 3, 1, 2)
74+
75+
76+
class ImageClassifier(nn.Module):
    """Wrap CLIP so it scores an image against a fixed list of text labels.

    The wrapper exposes a plain ``forward(pixel_values) -> probabilities``
    interface, with one probability per entry of ``labels``.

    Args:
        labels: list of text prompts the image is compared against.
    """

    def __init__(self, labels):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.labels = labels

    def forward(self, x):
        """Return the softmax over labels for each image in ``x``.

        Args:
            x: batch of preprocessed images, passed to CLIP as
                ``pixel_values``.

        Returns:
            Tensor of shape (batch, len(self.labels)) holding the softmax of
            CLIP's ``logits_per_image``. The probabilities for the first
            image in the batch are also printed.
        """
        # Use the labels stored on the instance; relying on a module-level
        # ``labels`` only works when the file is run as a script.
        text_inputs = self.processor(text=self.labels, return_tensors="pt", padding=True)

        # The tokenizer output is created on the CPU; move it next to the
        # image batch so the model never sees tensors on mixed devices.
        input_ids = text_inputs['input_ids'].to(x.device)
        attention_mask = text_inputs['attention_mask'].to(x.device)

        outputs = self.clip(pixel_values=x, input_ids=input_ids, attention_mask=attention_mask)

        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

        for label, prob in zip(self.labels, probs[0]):
            print(f"{label}: {prob:.4f}")
        return probs
94+
95+
96+
if __name__ == '__main__':
    # Usage: python clip_example.py --image-path <path_to_image>
    #
    # Example of running the CAM methods on a CLIP-based image classifier.
    # The resulting heat-map overlay is written to "<method>_cam.jpg" in the
    # current working directory.

    args = get_args()
    methods = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
        "layercam": LayerCAM,
        "fullgrad": FullGrad,
    }

    # Validate once, before the (slow) model download/load.
    if args.method not in methods:
        raise Exception(f"method should be one of {list(methods.keys())}")

    labels = args.labels
    model = ImageClassifier(labels)
    if args.use_cuda:
        model.cuda()
    model.eval()
    print(model)

    # CAMs are computed on the first layer norm of the last vision encoder
    # block.
    target_layers = [model.clip.vision_model.encoder.layers[-1].layer_norm1]

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail with a clear message rather than a TypeError below.
    bgr_img = cv2.imread(args.image_path, 1)
    if bgr_img is None:
        raise FileNotFoundError(
            f"Could not read image: {args.image_path}")

    # OpenCV loads BGR; flip the channel axis to get RGB.
    rgb_img = bgr_img[:, :, ::-1]
    # 224 pixels with the patch-14 checkpoint gives 16 patches per side,
    # which is the grid reshape_transform assumes by default.
    rgb_img = cv2.resize(rgb_img, (224, 224))
    rgb_img = np.float32(rgb_img) / 255
    # NOTE(review): CLIP checkpoints usually ship their own per-channel
    # mean/std; confirm that 0.5/0.5 is intended here.
    input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5])

    cam_kwargs = {
        "model": model,
        "target_layers": target_layers,
        "reshape_transform": reshape_transform,
    }
    if args.method == "ablationcam":
        # AblationCAM is given the ViT ablation layer, as the activations
        # here are token sequences.
        cam_kwargs["ablation_layer"] = AblationLayerVit()
    cam = methods[args.method](**cam_kwargs)

    # If None, returns the map for the highest scoring category.
    # Otherwise, targets the requested category.
    targets = None
    print(input_tensor.shape)

    # AblationCAM and ScoreCAM have batched implementations.
    # You can override the internal batch size for faster computation.
    cam.batch_size = 32

    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=targets,
                        eigen_smooth=args.eigen_smooth,
                        aug_smooth=args.aug_smooth)

    # Here grayscale_cam has only one image in the batch
    grayscale_cam = grayscale_cam[0, :]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)
    cv2.imwrite(f'{args.method}_cam.jpg', cam_image)

0 commit comments

Comments
 (0)