Skip to content

Commit e095ed6

Browse files
committed
fix: 🐛 fix bug in cw-ssim and liqe
1 parent af6c6cc commit e095ed6

File tree

5 files changed

+13
-9
lines changed

5 files changed

+13
-9
lines changed

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ release:
2121
test:
2222
pytest tests/ -m calibration -v
2323
pytest tests/test_metric_general.py::test_forward -v
24+
pytest tests/test_metric_general.py::test_cpu_gpu_consistency -v
2425

25-
test_general:
26+
test_cs:
2627
pytest tests/test_metric_general.py::test_cpu_gpu_consistency -v
2728

28-
test_gradient:
29+
test_grad:
2930
pytest tests/test_metric_general.py::test_gradient_backward -v
3031

3132
test_dataset:

ResultsCalibra/calibration_summary.csv

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ gmsd(ours),0.2203,0.0005,0.0004,0.1346,0.205
1616
ilniqe,113.4801,23.9968,19.975,22.4493,56.6721
1717
ilniqe(ours),115.6107,24.0661,19.7494,22.3251,54.7635
1818
laion_aes,3.6420,5.5836,5.0716,4.6458,3.0889
19-
laion_aes(ours),3.6992,5.5835,5.1028,4.6500,3.0948
19+
laion_aes(ours),3.7204,5.5917,5.0756,4.6551,3.0973
2020
lpips,0.7237,0.2572,0.0508,0.052,0.4253
2121
lpips(ours),0.7237,0.2572,0.0508,0.0521,0.4253
2222
mad,195.2796,80.8379,30.3918,84.3542,202.2371
@@ -46,7 +46,7 @@ pi(ours),11.9286,3.073,2.6357,2.7979,6.9546
4646
psnr,21.11,20.99,27.01,23.3,21.62
4747
psnr(ours),21.1136,20.9872,27.0139,23.3002,21.6186
4848
ssim,0.6993,0.9978,0.9989,0.9669,0.6519
49-
ssim(ours),0.6997,0.9978,0.9989,0.9671,0.6521
49+
ssim(ours),0.6997,0.9978,0.9989,0.9671,0.6522
5050
vif,0.0172,0.9891,0.9924,0.9103,0.1745
5151
vif(ours),0.0172,0.9891,0.9924,0.9104,0.175
5252
vsi,0.9139,0.962,0.9922,0.9571,0.9262

pyiqa/archs/liqe_arch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,21 @@ def forward(self, x):
9292

9393
if x.size(1) < self.num_patch:
9494
num_patch = x.size(1)
95+
self.num_patch = num_patch
9596
else:
9697
num_patch = self.num_patch
9798

9899
if self.training:
99100
sel = torch.randint(low=0, high=x.size(0), size=(num_patch, ))
100101
else:
101-
sel_step = x.size(1) // self.num_patch
102+
sel_step = max(1, x.size(1) // self.num_patch)
102103
sel = torch.zeros(num_patch)
103104
for i in range(num_patch):
104105
sel[i] = sel_step * i
105106
sel = sel.long()
106107

107108
x = x[:, sel, ...]
109+
x = x.reshape(bs, num_patch, x.shape[2], x.shape[3], x.shape[4])
108110

109111
text_features = self.clip_model.encode_text(self.joint_texts.to(x.device))
110112
text_features = text_features / text_features.norm(dim=1, keepdim=True)
@@ -130,5 +132,4 @@ def forward(self, x):
130132

131133
quality = 1 * logits_quality[:, 0] + 2 * logits_quality[:, 1] + 3 * logits_quality[:, 2] + \
132134
4 * logits_quality[:, 3] + 5 * logits_quality[:, 4]
133-
134135
return quality

pyiqa/archs/ssim_arch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ def cw_ssim(self, x, y, test_y_channel):
266266
w = fspecial(s - 7 + 1, s[0] / 4, 1).to(x.device)
267267
gb = int(self.guardb / (2**(self.level - 1)))
268268

269+
self.win7 = self.win7.to(x.dtype)
270+
269271
for i in range(self.ori):
270272

271273
band1 = cw_x[bandind][i]

tests/test_metric_general.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_match_official_with_given_cases(ref_img, dist_img, metric_name, device)
8888
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
8989
@pytest.mark.parametrize(
9090
("metric_name"),
91-
[(k) for k in pyiqa.list_models() if k not in ['ahiq', 'fid', 'vsi', 'clipscore']]
91+
[(k) for k in pyiqa.list_models() if k not in ['ahiq', 'fid', 'vsi', 'clipscore', 'topiq_nr-face', 'tres', 'tres-koniq']]
9292
)
9393
def test_cpu_gpu_consistency(metric_name):
9494
"""Test if the metric results are consistent between CPU and GPU.
@@ -121,7 +121,7 @@ def test_cpu_gpu_consistency(metric_name):
121121

122122
@pytest.mark.parametrize(
123123
("metric_name"),
124-
[(k) for k in pyiqa.list_models() if k not in ['pi', 'nrqm', 'fid', 'mad', 'vsi', 'clipscore', 'entropy']]
124+
[(k) for k in pyiqa.list_models() if k not in ['pi', 'nrqm', 'fid', 'mad', 'vsi', 'clipscore', 'entropy', 'topiq_nr-face']]
125125
)
126126
def test_gradient_backward(metric_name, device):
127127
"""Test if the metric can be used in a gradient descent process.
@@ -156,7 +156,7 @@ def test_gradient_backward(metric_name, device):
156156

157157
@pytest.mark.parametrize(
158158
("metric_name"),
159-
[(k) for k in pyiqa.list_models() if k not in ['fid', 'clipscore']]
159+
[(k) for k in pyiqa.list_models() if k not in ['fid', 'clipscore', 'topiq_nr-face']]
160160
)
161161
def test_forward(metric_name, device):
162162
"""Test if the metric can be used in a gradient descent process.

0 commit comments

Comments
 (0)