
Commit a2a23f8

Add Gaudi hpu accelerator option to BaseCAM (#547)
* add gaudi hpu option

  Signed-off-by: Daniel Deleon <[email protected]>

* add try except block

  Signed-off-by: Daniel Deleon <[email protected]>

---------

Signed-off-by: Daniel Deleon <[email protected]>
1 parent 5cef718 commit a2a23f8

File tree: README.md, cam.py, pytorch_grad_cam/base_cam.py

3 files changed, 13 insertions(+), 1 deletion(-)

README.md

+1 -1
@@ -281,7 +281,7 @@ two smoothing methods are supported:
 Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_dir_path> `
 
 
-To use with a specific device, like cpu, cuda, cuda:0 or mps:
+To use with a specific device, like cpu, cuda, cuda:0, mps or hpu:
 `python cam.py --image-path <path_to_image> --device cuda --output-dir <output_dir_path> `
 
 ----------
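With this change, the command pattern shown in the README also covers a Gaudi device, e.g. `python cam.py --image-path <path_to_image> --device hpu --output-dir <output_dir_path>` (this assumes a Gaudi machine with Habana's PyTorch packages installed).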

cam.py

+3
@@ -77,6 +77,9 @@ def get_args():
         'kpcacam': KPCA_CAM
     }
 
+    if args.device=='hpu':
+        import habana_frameworks.torch.core as htcore
+
     model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()
 
     # Choose the target layer you want to compute the visualization for.
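For context, a minimal sketch of what the new branch in cam.py enables: once `habana_frameworks.torch.core` is imported, the model and inputs can be placed on the `hpu` device the same way as `cuda` or `cpu`. The hard-coded `device` string and dummy input below are placeholders for illustration, not part of the commit, and assume Habana's PyTorch bridge is installed.

```python
import torch
from torchvision import models

device = "hpu"  # would normally come from argparse, as in cam.py

# Importing the Habana PyTorch bridge is what makes the 'hpu' device
# usable from PyTorch; without it, torch.device('hpu') cannot be used.
if device == "hpu":
    import habana_frameworks.torch.core as htcore  # noqa: F401

model = models.resnet50(pretrained=True).to(torch.device(device)).eval()
input_tensor = torch.rand(1, 3, 224, 224, device=torch.device(device))

with torch.no_grad():
    print(model(input_tensor).shape)  # expected: torch.Size([1, 1000])
```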

pytorch_grad_cam/base_cam.py

+9
@@ -25,6 +25,13 @@ def __init__(
 
         # Use the same device as the model.
         self.device = next(self.model.parameters()).device
+        if 'hpu' in str(self.device):
+            try:
+                import habana_frameworks.torch.core as htcore
+            except ImportError as error:
+                error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
+                raise error
+            self.__htcore = htcore
         self.reshape_transform = reshape_transform
         self.compute_input_gradient = compute_input_gradient
         self.uses_gradients = uses_gradients
@@ -97,6 +104,8 @@ def forward(
             self.model.zero_grad()
            loss = sum([target(output) for target, output in zip(targets, outputs)])
            loss.backward(retain_graph=True)
+            if 'hpu' in str(self.device):
+                self.__htcore.mark_step()
 
         # In most of the saliency attribution papers, the saliency is
         # computed with a single target layer.
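The reason for calling `mark_step()` right after `loss.backward()`: Habana's PyTorch integration defaults to lazy execution, where operations are accumulated into a graph and only dispatched to the device at a step boundary. A rough, self-contained sketch of that pattern, using a toy linear model rather than the CAM pipeline, and assuming the Habana bridge is installed:

```python
import torch
import habana_frameworks.torch.core as htcore  # assumed available on a Gaudi machine

device = torch.device("hpu")
model = torch.nn.Linear(16, 4).to(device)
inputs = torch.rand(8, 16, device=device)

outputs = model(inputs)
loss = outputs.sum()  # toy stand-in for the CAM target loss

model.zero_grad()
loss.backward(retain_graph=True)

# In lazy mode the backward ops are only queued; mark_step() triggers graph
# execution so that gradients captured by hooks are actually materialized.
if "hpu" in str(device):
    htcore.mark_step()

print(model.weight.grad.shape)  # torch.Size([4, 16])
```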
