Skip to content

Commit 257f24e

Browse files
committed
fix: 🚑 fix ssim calculation in nrqm
1 parent d9af97d commit 257f24e

File tree

4 files changed

+44
-45
lines changed

4 files changed

+44
-45
lines changed

pyiqa/archs/func_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@
22
import torch
33
import torch.nn.functional as F
44

5+
from pyiqa.utils.color_util import to_y_channel
56
from pyiqa.matlab_utils import fspecial, imfilter
67
from .arch_util import excact_padding_2d
78

89
EPS = torch.finfo(torch.float32).eps
910

1011

12+
# Prepare an image tensor for a metric: bring it onto the `data_range` scale,
# optionally converting a 3-channel input through `to_y_channel`, then round.
#
# Args:
#   x: input tensor; `x.shape[1]` is compared against 3
#      (presumably the channel axis of an (N, C, H, W) batch -- TODO confirm).
#   test_y_channel: when true and `x.shape[1] == 3`, convert via `to_y_channel`.
#   data_range: passed to `to_y_channel`, or multiplied into `x` otherwise.
#   color_space: forwarded unchanged to `to_y_channel`.
# Returns:
#   The converted or scaled tensor after the rounding statement below.
def preprocess_rgb(x, test_y_channel, data_range=1., color_space='yiq'):
13+
if test_y_channel and x.shape[1] == 3:
14+
x = to_y_channel(x, data_range, color_space)
15+
else:
16+
x = x * data_range
17+
# Forward value equals x.round(): `x - x.detach()` is numerically zero, but it
# keeps `x` in the autograd graph so the rounding does not cut off gradients.
# NOTE(review): indentation is lost in this diff rendering, so it cannot be
# told here whether this statement belongs to the `else` branch only or runs
# after both branches -- confirm against the real file.
x = x - x.detach() + x.round()
18+
return x
19+
20+
1121
def extract_2d_patches(x, kernel, stride=1, dilation=1, padding='same'):
1222
"""
1323
Ref: https://stackoverflow.com/a/65886666

pyiqa/archs/nrqm_arch.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pyiqa.utils.download_util import load_file_from_url
2020
from pyiqa.matlab_utils import imresize, fspecial, SCFpyr_PyTorch, dct2d, im2col
2121
from pyiqa.archs.func_util import extract_2d_patches
22-
from pyiqa.archs.ssim_arch import SSIM
22+
from pyiqa.archs.ssim_arch import ssim as ssim_func
2323
from pyiqa.archs.arch_util import ExactPadding2d
2424
from pyiqa.archs.niqe_arch import NIQE
2525
from warnings import warn
@@ -211,7 +211,7 @@ def norm_sender_normalized(pyr, num_scale=2, num_bands=6, blksz=3, eps=1e-12):
211211
L, Q = torch.linalg.eigh(C_x)
212212
L_pos = L * (L > 0)
213213
L_pos_sum = L_pos.sum(dim=1, keepdim=True)
214-
L = L_pos * L.sum(dim=1, keepdim=True) / (L_pos_sum + (L_pos_sum == 0).float())
214+
L = L_pos * L.sum(dim=1, keepdim=True) / (L_pos_sum + (L_pos_sum == 0).to(L.dtype))
215215
C_x = Q @ torch.diag_embed(L) @ Q.transpose(1, 2)
216216

217217
o_c = current_band[:, border:-border, border:-border]
@@ -267,16 +267,16 @@ def global_gsm(img: Tensor):
267267

268268
# structure correlation between scales
269269
hp_band = pyr[0]
270-
ssim_func = SSIM(channels=1, test_y_channel=False)
271-
for sb in subbands:
272-
sb_tmp = imresize(sb, sizes=hp_band.shape[1:]).unsqueeze(1)
273-
tmp_ssim = ssim_func(sb_tmp, hp_band.unsqueeze(1))
274-
feat.append(tmp_ssim)
270+
for sb in lp_bands:
271+
curr_band = imresize(sb, sizes=hp_band.shape[1:]).unsqueeze(1)
272+
_, tmpscore = ssim_func(curr_band, hp_band.unsqueeze(1), get_cs=True, data_range=255)
273+
feat.append(tmpscore)
275274

276275
# structure correlation between orientations
277276
for i in range(num_bands):
278277
for j in range(i + 1, num_bands):
279-
feat.append(ssim_func(subbands[i].unsqueeze(1), subbands[j].unsqueeze(1)))
278+
_, tmpscore = ssim_func(subbands[i].unsqueeze(1), subbands[j].unsqueeze(1), get_cs=True, data_range=255)
279+
feat.append(tmpscore)
280280

281281
feat = torch.stack(feat, dim=1)
282282
return feat
@@ -316,7 +316,7 @@ def random_forest_regression(feat, ldau, rdau, threshold_value, pred_value, best
316316
best_attri[:, i])
317317
tmp_pred.append(tmp_result)
318318
pred.append(tmp_pred)
319-
pred = torch.Tensor(pred)
319+
pred = torch.tensor(pred)
320320
return pred.mean(dim=1, keepdim=True)
321321

322322

@@ -335,7 +335,8 @@ def nrqm(
335335

336336
# crop image
337337
b, c, h, w = img.shape
338-
img_pyr = get_guass_pyramid(img.float() / 255.)
338+
img = img.double()
339+
img_pyr = get_guass_pyramid(img / 255.)
339340

340341
# DCT features
341342
f1 = []
@@ -359,7 +360,7 @@ def nrqm(
359360
for feat, rf in zip([f1, f2, f3], rf_param):
360361
tmp_pred = random_forest_regression(feat, *rf)
361362
preds = torch.cat((preds, tmp_pred), dim=1)
362-
quality = preds @ torch.Tensor(linear_param)
363+
quality = preds @ torch.tensor(linear_param)
363364

364365
return quality.squeeze()
365366

@@ -396,7 +397,7 @@ def calculate_nrqm(img: torch.Tensor,
396397

397398
if test_y_channel and img.shape[1] == 3:
398399
img = to_y_channel(img, 255, color_space)
399-
400+
400401
if crop_border != 0:
401402
img = img[..., crop_border:-crop_border, crop_border:-crop_border]
402403

pyiqa/archs/ssim_arch.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,21 @@
2424
from pyiqa.utils.color_util import to_y_channel
2525
from pyiqa.matlab_utils import fspecial, SCFpyr_PyTorch, math_util, filter2
2626
from pyiqa.utils.registry import ARCH_REGISTRY
27+
from .func_util import preprocess_rgb
2728

2829

2930
def ssim(X,
3031
Y,
31-
win,
32+
win=None,
3233
get_ssim_map=False,
3334
get_cs=False,
3435
get_weight=False,
3536
downsample=False,
3637
data_range=1.,
37-
test_y_channel=True,
38-
color_space='yiq'):
39-
40-
data_range = 255
41-
# Whether calculate on y channel of ycbcr
42-
if test_y_channel and X.shape[1] == 3:
43-
X = to_y_channel(X, data_range, color_space)
44-
Y = to_y_channel(Y, data_range, color_space)
45-
else:
46-
X = X * data_range
47-
X = X - X.detach() + X.round()
48-
Y = Y * data_range
49-
Y = Y - Y.detach() + Y.round()
50-
38+
):
39+
if win is None:
40+
win = fspecial(11, 1.5, X.shape[1]).to(X)
41+
5142
C1 = (0.01 * data_range)**2
5243
C2 = (0.03 * data_range)**2
5344

@@ -58,8 +49,6 @@ def ssim(X,
5849
X = F.avg_pool2d(X, kernel_size=f)
5950
Y = F.avg_pool2d(Y, kernel_size=f)
6051

61-
win = win.to(X.device)
62-
6352
mu1 = filter2(X, win, 'valid')
6453
mu2 = filter2(Y, win, 'valid')
6554
mu1_sq = mu1.pow(2)
@@ -98,11 +87,11 @@ class SSIM(torch.nn.Module):
9887
# Constructor of the SSIM module: stores the configuration on the instance.
# This is a diff overlay; gutter `NNN-` marks a line removed by the commit and
# `NN+` marks a line added by it.
def __init__(self, channels=3, downsample=False, test_y_channel=True, color_space='yiq', crop_border=0.):
9988

10089
super(SSIM, self).__init__()
101-
# Removed line: the window built by fspecial is no longer cached on the module.
# With it gone, `channels` is not read by any remaining line of this constructor.
self.win = fspecial(11, 1.5, channels)
10290
self.downsample = downsample
10391
self.test_y_channel = test_y_channel
10492
self.color_space = color_space
10593
self.crop_border = crop_border
94+
# Added line: fixed scale that forward() passes to preprocess_rgb and ssim.
self.data_range = 255
10695

10796
def forward(self, X, Y):
10897
assert X.shape == Y.shape, f'Input {X.shape} and reference images should have the same shape'
@@ -111,14 +100,11 @@ def forward(self, X, Y):
111100
crop_border = self.crop_border
112101
X = X[..., crop_border:-crop_border, crop_border:-crop_border]
113102
Y = Y[..., crop_border:-crop_border, crop_border:-crop_border]
103+
104+
X = preprocess_rgb(X, self.test_y_channel, self.data_range, self.color_space)
105+
Y = preprocess_rgb(Y, self.test_y_channel, self.data_range, self.color_space)
114106

115-
score = ssim(
116-
X,
117-
Y,
118-
win=self.win,
119-
downsample=self.downsample,
120-
test_y_channel=self.test_y_channel,
121-
color_space=self.color_space)
107+
score = ssim(X, Y, data_range=self.data_range, downsample=self.downsample)
122108
return score
123109

124110

@@ -185,11 +171,11 @@ class MS_SSIM(torch.nn.Module):
185171

186172
# Constructor of the MS_SSIM module: stores the configuration on the instance.
# This is a diff overlay; gutter `NNN-` marks a line removed by the commit and
# `NNN+` marks a line added by it.
def __init__(self, channels=3, downsample=False, test_y_channel=True, is_prod=True, color_space='yiq'):
187173
super(MS_SSIM, self).__init__()
188-
# Removed line: the window built by fspecial is no longer cached on the module.
# With it gone, `channels` is not read by any remaining line of this constructor.
self.win = fspecial(11, 1.5, channels)
189174
self.downsample = downsample
190175
self.test_y_channel = test_y_channel
191176
self.color_space = color_space
192177
self.is_prod = is_prod
178+
# Added line: fixed scale that forward() passes to preprocess_rgb and ms_ssim.
self.data_range = 255
193179

194180
def forward(self, X, Y):
195181
"""Computation of MS-SSIM metric.
@@ -201,14 +187,16 @@ def forward(self, X, Y):
201187
"""
202188
assert X.shape == Y.shape, 'Input and reference images should have the same shape, but got'
203189
f'{X.shape} and {Y.shape}'
190+
191+
X = preprocess_rgb(X, self.test_y_channel, self.data_range, self.color_space)
192+
Y = preprocess_rgb(Y, self.test_y_channel, self.data_range, self.color_space)
193+
204194
score = ms_ssim(
205-
X,
206-
Y,
207-
win=self.win,
195+
X, Y,
196+
data_range=self.data_range,
208197
downsample=self.downsample,
209-
test_y_channel=self.test_y_channel,
210-
is_prod=self.is_prod,
211-
color_space=self.color_space)
198+
is_prod=self.is_prod
199+
)
212200
return score
213201

214202

pyiqa/matlab_utils/scfpyr_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def build(self, im_batch):
6363

6464
assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(
6565
self.device, im_batch.device)
66-
assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
66+
# assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
6767
assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
6868
assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'
6969

0 commit comments

Comments
 (0)