Skip to content

Commit 0d86e52

Browse files
committed
clip
1 parent ded280c commit 0d86e52

File tree

1 file changed

+167
-0
lines changed

1 file changed

+167
-0
lines changed

usage_examples/clip_example.py

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import argparse
2+
3+
import cv2
4+
import numpy as np
5+
import torch
6+
from torch import nn
7+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
8+
try:
9+
from transformers import CLIPProcessor, CLIPModel
10+
except ImportError:
11+
print("The transformers package is not installed. Please install it to use CLIP.")
12+
exit(1)
13+
14+
15+
from pytorch_grad_cam import GradCAM, \
16+
ScoreCAM, \
17+
GradCAMPlusPlus, \
18+
AblationCAM, \
19+
XGradCAM, \
20+
EigenCAM, \
21+
EigenGradCAM, \
22+
LayerCAM, \
23+
FullGrad
24+
25+
from pytorch_grad_cam.utils.image import show_cam_on_image, \
26+
preprocess_image
27+
from pytorch_grad_cam.ablation_layer import AblationLayerVit
28+
29+
30+
def get_args():
31+
parser = argparse.ArgumentParser()
32+
parser.add_argument('--device', type=str, default='cpu',
33+
help='Torch device to use')
34+
parser.add_argument(
35+
'--image-path',
36+
type=str,
37+
default='./examples/both.png',
38+
help='Input image path')
39+
parser.add_argument(
40+
'--labels',
41+
type=str,
42+
nargs='+',
43+
default=["a cat", "a dog", "a car", "a person", "a shoe"],
44+
help='need recognition labels'
45+
)
46+
47+
parser.add_argument('--aug_smooth', action='store_true',
48+
help='Apply test time augmentation to smooth the CAM')
49+
parser.add_argument(
50+
'--eigen_smooth',
51+
action='store_true',
52+
help='Reduce noise by taking the first principle componenet'
53+
'of cam_weights*activations')
54+
55+
parser.add_argument(
56+
'--method',
57+
type=str,
58+
default='gradcam',
59+
help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam')
60+
61+
args = parser.parse_args()
62+
if args.device:
63+
print(f'Using device "{args.device}" for acceleration')
64+
else:
65+
print('Using CPU for computation')
66+
67+
return args
68+
69+
70+
def reshape_transform(tensor, height=16, width=16):
71+
result = tensor[:, 1:, :].reshape(tensor.size(0),
72+
height, width, tensor.size(2))
73+
74+
# Bring the channels to the first dimension,
75+
# like in CNNs.
76+
result = result.transpose(2, 3).transpose(1, 2)
77+
return result
78+
79+
80+
class ImageClassifier(nn.Module):
81+
def __init__(self, labels):
82+
super(ImageClassifier, self).__init__()
83+
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
84+
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
85+
self.labels = labels
86+
87+
def forward(self, x):
88+
text_inputs = self.processor(text=self.labels, return_tensors="pt", padding=True)
89+
90+
outputs = self.clip(pixel_values=x, input_ids=text_inputs['input_ids'].to(self.clip.device),
91+
attention_mask=text_inputs['attention_mask'].to(self.clip.device))
92+
93+
logits_per_image = outputs.logits_per_image
94+
probs = logits_per_image.softmax(dim=1)
95+
96+
for label, prob in zip(self.labels, probs[0]):
97+
print(f"{label}: {prob:.4f}")
98+
return probs
99+
100+
101+
if __name__ == '__main__':
102+
""" python vit_gradcam.py --image-path <path_to_image>
103+
Example usage of using cam-methods on a VIT network.
104+
105+
"""
106+
107+
args = get_args()
108+
methods = \
109+
{"gradcam": GradCAM,
110+
"scorecam": ScoreCAM,
111+
"gradcam++": GradCAMPlusPlus,
112+
"ablationcam": AblationCAM,
113+
"xgradcam": XGradCAM,
114+
"eigencam": EigenCAM,
115+
"eigengradcam": EigenGradCAM,
116+
"layercam": LayerCAM,
117+
"fullgrad": FullGrad}
118+
119+
if args.method not in list(methods.keys()):
120+
raise Exception(f"method should be one of {list(methods.keys())}")
121+
122+
labels = args.labels
123+
model = ImageClassifier(labels).to(torch.device(args.device)).eval()
124+
print(model)
125+
126+
target_layers = [model.clip.vision_model.encoder.layers[-1].layer_norm1]
127+
128+
if args.method not in methods:
129+
raise Exception(f"Method {args.method} not implemented")
130+
131+
rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
132+
rgb_img = cv2.resize(rgb_img, (224, 224))
133+
rgb_img = np.float32(rgb_img) / 255
134+
input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
135+
std=[0.5, 0.5, 0.5]).to(args.device)
136+
137+
if args.method == "ablationcam":
138+
cam = methods[args.method](model=model,
139+
target_layers=target_layers,
140+
reshape_transform=reshape_transform,
141+
ablation_layer=AblationLayerVit())
142+
else:
143+
cam = methods[args.method](model=model,
144+
target_layers=target_layers,
145+
reshape_transform=reshape_transform)
146+
147+
148+
149+
# If None, returns the map for the highest scoring category.
150+
# Otherwise, targets the requested category.
151+
#targets = [ClassifierOutputTarget(1)]
152+
targets = None
153+
154+
# AblationCAM and ScoreCAM have batched implementations.
155+
# You can override the internal batch size for faster computation.
156+
cam.batch_size = 32
157+
158+
grayscale_cam = cam(input_tensor=input_tensor,
159+
targets=targets,
160+
eigen_smooth=args.eigen_smooth,
161+
aug_smooth=args.aug_smooth)
162+
163+
# Here grayscale_cam has only one image in the batch
164+
grayscale_cam = grayscale_cam[0, :]
165+
166+
cam_image = show_cam_on_image(rgb_img, grayscale_cam)
167+
cv2.imwrite(f'{args.method}_cam.jpg', cam_image)

0 commit comments

Comments
 (0)