2424from pyiqa .utils .color_util import to_y_channel
2525from pyiqa .matlab_utils import fspecial , SCFpyr_PyTorch , math_util , filter2
2626from pyiqa .utils .registry import ARCH_REGISTRY
27+ from .func_util import preprocess_rgb
2728
2829
2930def ssim (X ,
3031 Y ,
31- win ,
32+ win = None ,
3233 get_ssim_map = False ,
3334 get_cs = False ,
3435 get_weight = False ,
3536 downsample = False ,
3637 data_range = 1. ,
37- test_y_channel = True ,
38- color_space = 'yiq' ):
39-
40- data_range = 255
41- # Whether calculate on y channel of ycbcr
42- if test_y_channel and X .shape [1 ] == 3 :
43- X = to_y_channel (X , data_range , color_space )
44- Y = to_y_channel (Y , data_range , color_space )
45- else :
46- X = X * data_range
47- X = X - X .detach () + X .round ()
48- Y = Y * data_range
49- Y = Y - Y .detach () + Y .round ()
50-
38+ ):
39+ if win is None :
40+ win = fspecial (11 , 1.5 , X .shape [1 ]).to (X )
41+
5142 C1 = (0.01 * data_range )** 2
5243 C2 = (0.03 * data_range )** 2
5344
@@ -58,8 +49,6 @@ def ssim(X,
5849 X = F .avg_pool2d (X , kernel_size = f )
5950 Y = F .avg_pool2d (Y , kernel_size = f )
6051
61- win = win .to (X .device )
62-
6352 mu1 = filter2 (X , win , 'valid' )
6453 mu2 = filter2 (Y , win , 'valid' )
6554 mu1_sq = mu1 .pow (2 )
@@ -98,11 +87,11 @@ class SSIM(torch.nn.Module):
9887 def __init__ (self , channels = 3 , downsample = False , test_y_channel = True , color_space = 'yiq' , crop_border = 0. ):
9988
10089 super (SSIM , self ).__init__ ()
101- self .win = fspecial (11 , 1.5 , channels )
10290 self .downsample = downsample
10391 self .test_y_channel = test_y_channel
10492 self .color_space = color_space
10593 self .crop_border = crop_border
94+ self .data_range = 255
10695
10796 def forward (self , X , Y ):
10897 assert X .shape == Y .shape , f'Input { X .shape } and reference images should have the same shape'
@@ -111,14 +100,11 @@ def forward(self, X, Y):
111100 crop_border = self .crop_border
112101 X = X [..., crop_border :- crop_border , crop_border :- crop_border ]
113102 Y = Y [..., crop_border :- crop_border , crop_border :- crop_border ]
103+
104+ X = preprocess_rgb (X , self .test_y_channel , self .data_range , self .color_space )
105+ Y = preprocess_rgb (Y , self .test_y_channel , self .data_range , self .color_space )
114106
115- score = ssim (
116- X ,
117- Y ,
118- win = self .win ,
119- downsample = self .downsample ,
120- test_y_channel = self .test_y_channel ,
121- color_space = self .color_space )
107+ score = ssim (X , Y , data_range = self .data_range , downsample = self .downsample )
122108 return score
123109
124110
@@ -185,11 +171,11 @@ class MS_SSIM(torch.nn.Module):
185171
186172 def __init__ (self , channels = 3 , downsample = False , test_y_channel = True , is_prod = True , color_space = 'yiq' ):
187173 super (MS_SSIM , self ).__init__ ()
188- self .win = fspecial (11 , 1.5 , channels )
189174 self .downsample = downsample
190175 self .test_y_channel = test_y_channel
191176 self .color_space = color_space
192177 self .is_prod = is_prod
178+ self .data_range = 255
193179
194180 def forward (self , X , Y ):
195181 """Computation of MS-SSIM metric.
@@ -201,14 +187,16 @@ def forward(self, X, Y):
201187 """
202188 assert X .shape == Y .shape , 'Input and reference images should have the same shape, but got'
203189 f'{ X .shape } and { Y .shape } '
190+
191+ X = preprocess_rgb (X , self .test_y_channel , self .data_range , self .color_space )
192+ Y = preprocess_rgb (Y , self .test_y_channel , self .data_range , self .color_space )
193+
204194 score = ms_ssim (
205- X ,
206- Y ,
207- win = self .win ,
195+ X , Y ,
196+ data_range = self .data_range ,
208197 downsample = self .downsample ,
209- test_y_channel = self .test_y_channel ,
210- is_prod = self .is_prod ,
211- color_space = self .color_space )
198+ is_prod = self .is_prod
199+ )
212200 return score
213201
214202
0 commit comments