Skip to content

Commit 5cef718

Browse files
Dynamically move mask to device in SemanticSegmentationTarget (#546)
Signed-off-by: Shreyas Ranganatha <[email protected]>
1 parent 0d86e52 commit 5cef718

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

pytorch_grad_cam/utils/model_targets.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,9 @@ class SemanticSegmentationTarget:
5959
def __init__(self, category, mask):
6060
self.category = category
6161
self.mask = torch.from_numpy(mask)
62-
if torch.cuda.is_available():
63-
self.mask = self.mask.cuda()
64-
if torch.backends.mps.is_available():
65-
self.mask = self.mask.to("mps")
6662

6763
def __call__(self, model_output):
68-
return (model_output[self.category, :, :] * self.mask).sum()
64+
return (model_output[self.category, :, :] * self.mask.to(model_output.device)).sum()
6965

7066

7167
class FasterRCNNBoxScoreTarget:

0 commit comments

Comments
 (0)