We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
SemanticSegmentationTarget
1 parent 0d86e52 commit 5cef718Copy full SHA for 5cef718
pytorch_grad_cam/utils/model_targets.py
@@ -59,13 +59,9 @@ class SemanticSegmentationTarget:
59
def __init__(self, category, mask):
60
self.category = category
61
self.mask = torch.from_numpy(mask)
62
- if torch.cuda.is_available():
63
- self.mask = self.mask.cuda()
64
- if torch.backends.mps.is_available():
65
- self.mask = self.mask.to("mps")
66
67
def __call__(self, model_output):
68
- return (model_output[self.category, :, :] * self.mask).sum()
+ return (model_output[self.category, :, :] * self.mask.to(model_output.device)).sum()
69
70
71
class FasterRCNNBoxScoreTarget:
0 commit comments