Skip to content

Commit ed3bcab

Browse files
authored
Merge pull request #17 from Nepelius/dev
feat: Device option for image loading when the image is a Pytorch Tensor
2 parents 3be6d6b + 2592d37 commit ed3bcab

File tree

6 files changed

+44
-8
lines changed

6 files changed

+44
-8
lines changed

src/viqa/_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, data_range, normalize, **kwargs):
3030
"normalize": normalize,
3131
"chromatic": False,
3232
"roi": None,
33-
**kwargs,
33+
"device": "cpu" ** kwargs,
3434
}
3535
self.score_val = None
3636
self._name = None

src/viqa/fr_metrics/fsim.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ class FSIM(FullReferenceMetricsInterface):
6666
If True, the input images are expected to be RGB images and FSIMc is
6767
calculated. See [1]_. Passed to
6868
:py:func:`piq.fsim`. See the documentation under [2]_.
69+
device : Union[str, torch.device], default 'cpu'
70+
Determines the device if the image is a PyTorch tensor,
71+
e.g. "cuda", "cpu", "cuda:0", ...
6972
7073
Raises
7174
------
@@ -197,6 +200,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
197200
img_r[im_slice, :, :],
198201
img_m[im_slice, :, :],
199202
self.parameters["chromatic"],
203+
self.parameters["device"],
200204
)
201205
score_val = fsim(
202206
img_r_tensor,
@@ -210,6 +214,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
210214
img_r[:, im_slice, :],
211215
img_m[:, im_slice, :],
212216
self.parameters["chromatic"],
217+
self.parameters["device"],
213218
)
214219
score_val = fsim(
215220
img_r_tensor,
@@ -223,6 +228,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
223228
img_r[:, :, im_slice],
224229
img_m[:, :, im_slice],
225230
self.parameters["chromatic"],
231+
self.parameters["device"],
226232
)
227233
score_val = fsim(
228234
img_r_tensor,
@@ -246,6 +252,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
246252
img_r,
247253
img_m,
248254
self.parameters["chromatic"],
255+
self.parameters["device"],
249256
)
250257
score_val = fsim(
251258
img_r_tensor,
@@ -268,6 +275,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
268275
img_r,
269276
img_m,
270277
self.parameters["chromatic"],
278+
self.parameters["device"],
271279
)
272280
score_val = fsim(
273281
img_r_tensor,

src/viqa/fr_metrics/msssim.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
145145
Algorithm parameter, K2 (small constant, see [3]_).
146146
Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN
147147
results.
148+
device : Union[str, torch.device], default 'cpu'
149+
Determines the device if the image is a PyTorch tensor,
150+
e.g. "cuda", "cpu", "cuda:0", ...
148151
149152
.. seealso::
150153
See :py:func:`.viqa.fr_metrics.ssim.structural_similarity` for more
@@ -202,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
202205
img_r[im_slice, :, :],
203206
img_m[im_slice, :, :],
204207
self.parameters["chromatic"],
208+
self.parameters["device"],
205209
)
206210
score_val = multi_scale_ssim(
207211
img_r_tensor,
@@ -214,6 +218,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
214218
img_r[:, im_slice, :],
215219
img_m[:, im_slice, :],
216220
self.parameters["chromatic"],
221+
self.parameters["device"],
217222
)
218223
score_val = multi_scale_ssim(
219224
img_r_tensor,
@@ -226,6 +231,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
226231
img_r[:, :, im_slice],
227232
img_m[:, :, im_slice],
228233
self.parameters["chromatic"],
234+
self.parameters["device"],
229235
)
230236
score_val = multi_scale_ssim(
231237
img_r_tensor,
@@ -249,6 +255,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
249255
img_r,
250256
img_m,
251257
self.parameters["chromatic"],
258+
self.parameters["device"],
252259
)
253260
score_val = multi_scale_ssim(
254261
img_r_tensor,
@@ -270,6 +277,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
270277
img_r,
271278
img_m,
272279
self.parameters["chromatic"],
280+
self.parameters["device"],
273281
)
274282
score_val = multi_scale_ssim(
275283
img_r_tensor,

src/viqa/fr_metrics/vif.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
127127
HVS model parameter (variance of the visual noise). See [3]_.
128128
reduction : str, default='mean'
129129
Specifies the reduction type: 'none', 'mean' or 'sum'.
130+
device : Union[str, torch.device], default 'cpu'
131+
Determines the device if the image is a PyTorch tensor,
132+
e.g. "cuda", "cpu", "cuda:0", ...
130133
131134
Returns
132135
-------
@@ -176,6 +179,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
176179
img_r[im_slice, :, :],
177180
img_m[im_slice, :, :],
178181
self.parameters["chromatic"],
182+
self.parameters["device"],
179183
)
180184
score_val = vif_p(
181185
img_r_tensor,
@@ -188,6 +192,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
188192
img_r[:, im_slice, :],
189193
img_m[:, im_slice, :],
190194
self.parameters["chromatic"],
195+
self.parameters["device"],
191196
)
192197
score_val = vif_p(
193198
img_r_tensor,
@@ -200,6 +205,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
200205
img_r[:, :, im_slice],
201206
img_m[:, :, im_slice],
202207
self.parameters["chromatic"],
208+
self.parameters["device"],
203209
)
204210
score_val = vif_p(
205211
img_r_tensor,
@@ -222,6 +228,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
222228
img_r,
223229
img_m,
224230
self.parameters["chromatic"],
231+
self.parameters["device"],
225232
)
226233
score_val = vif_p(
227234
img_r_tensor,
@@ -243,6 +250,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
243250
img_r,
244251
img_m,
245252
self.parameters["chromatic"],
253+
self.parameters["device"],
246254
)
247255
score_val = vif_p(
248256
img_r_tensor,

src/viqa/fr_metrics/vsi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class VSI(FullReferenceMetricsInterface):
6464
----------------
6565
chromatic : bool, default False
6666
If True, the input images are expected to be RGB images.
67+
device : Union[str, torch.device], default 'cpu'
68+
Determines the device if the image is a PyTorch tensor,
69+
e.g. "cuda", "cpu", "cuda:0", ...
6770
6871
Raises
6972
------
@@ -197,6 +200,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
197200
img_r[im_slice, :, :],
198201
img_m[im_slice, :, :],
199202
self.parameters["chromatic"],
203+
self.parameters["device"],
200204
)
201205
score_val = vsi(
202206
img_r_tensor,
@@ -209,6 +213,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
209213
img_r[:, im_slice, :],
210214
img_m[:, im_slice, :],
211215
self.parameters["chromatic"],
216+
self.parameters["device"],
212217
)
213218
score_val = vsi(
214219
img_r_tensor,
@@ -221,6 +226,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
221226
img_r[:, :, im_slice],
222227
img_m[:, :, im_slice],
223228
self.parameters["chromatic"],
229+
self.parameters["device"],
224230
)
225231
score_val = vsi(
226232
img_r_tensor,
@@ -243,6 +249,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
243249
img_r,
244250
img_m,
245251
self.parameters["chromatic"],
252+
self.parameters["device"],
246253
)
247254
score_val = vsi(
248255
img_r_tensor,
@@ -264,6 +271,7 @@ def score(self, img_r, img_m, dim=None, im_slice=None, **kwargs):
264271
img_r,
265272
img_m,
266273
self.parameters["chromatic"],
274+
self.parameters["device"],
267275
)
268276
score_val = vsi(
269277
img_r_tensor,

src/viqa/utils/misc.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -417,25 +417,29 @@ def _get_range(cols_rows):
417417
return res
418418

419419

420-
def _check_chromatic(img_r, img_m, chromatic):
420+
def _check_chromatic(img_r, img_m, chromatic, device):
421421
"""Permute image based on dimensions and chromaticity."""
422422
img_r = _to_float(img_r, np.float32)
423423
img_m = _to_float(img_m, np.float32)
424424
# check if chromatic
425425
if chromatic is False:
426426
if img_r.ndim == 3:
427427
# 3D images
428-
img_r_tensor = torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2)
429-
img_m_tensor = torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2)
428+
img_r_tensor = (
429+
torch.tensor(img_r).unsqueeze(0).permute(3, 0, 1, 2).to(device)
430+
)
431+
img_m_tensor = (
432+
torch.tensor(img_m).unsqueeze(0).permute(3, 0, 1, 2).to(device)
433+
)
430434
elif img_r.ndim == 2:
431435
# 2D images
432-
img_r_tensor = torch.tensor(img_r).unsqueeze(0).unsqueeze(0)
433-
img_m_tensor = torch.tensor(img_m).unsqueeze(0).unsqueeze(0)
436+
img_r_tensor = torch.tensor(img_r).unsqueeze(0).unsqueeze(0).to(device)
437+
img_m_tensor = torch.tensor(img_m).unsqueeze(0).unsqueeze(0).to(device)
434438
else:
435439
raise ValueError("Image format not supported.")
436440
else:
437-
img_r_tensor = torch.tensor(img_r).permute(2, 0, 1).unsqueeze(0)
438-
img_m_tensor = torch.tensor(img_m).permute(2, 0, 1).unsqueeze(0)
441+
img_r_tensor = torch.tensor(img_r).permute(2, 0, 1).unsqueeze(0).to(device)
442+
img_m_tensor = torch.tensor(img_m).permute(2, 0, 1).unsqueeze(0).to(device)
439443
return img_r_tensor, img_m_tensor
440444

441445

0 commit comments

Comments
 (0)