@@ -145,6 +145,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
145145 Algorithm parameter, K2 (small constant, see [3]_).
146146 Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN
147147 results.
148+ device : Union[str, torch.device], default 'cpu'
149+ Determines the device if the image is a PyTorch tensor,
150+ e.g. "cuda", "cpu", "cuda:0", ...
148151
149152 .. seealso::
150153 See :py:func:`.viqa.fr_metrics.ssim.structural_similarity` for more
@@ -202,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
202205 img_r [im_slice , :, :],
203206 img_m [im_slice , :, :],
204207 self .parameters ["chromatic" ],
208+ self .parameters ["device" ],
205209 )
206210 score_val = multi_scale_ssim (
207211 img_r_tensor ,
@@ -214,6 +218,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
214218 img_r [:, im_slice , :],
215219 img_m [:, im_slice , :],
216220 self .parameters ["chromatic" ],
221+ self .parameters ["device" ],
217222 )
218223 score_val = multi_scale_ssim (
219224 img_r_tensor ,
@@ -226,6 +231,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
226231 img_r [:, :, im_slice ],
227232 img_m [:, :, im_slice ],
228233 self .parameters ["chromatic" ],
234+ self .parameters ["device" ],
229235 )
230236 score_val = multi_scale_ssim (
231237 img_r_tensor ,
@@ -249,6 +255,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
249255 img_r ,
250256 img_m ,
251257 self .parameters ["chromatic" ],
258+ self .parameters ["device" ],
252259 )
253260 score_val = multi_scale_ssim (
254261 img_r_tensor ,
@@ -270,6 +277,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
270277 img_r ,
271278 img_m ,
272279 self .parameters ["chromatic" ],
280+ self .parameters ["device" ],
273281 )
274282 score_val = multi_scale_ssim (
275283 img_r_tensor ,
0 commit comments