-
Notifications
You must be signed in to change notification settings - Fork 3
Open
Description
Use MAD (mean absolute deviation) pooling: it is more accurate than mean pooling for the SSIM metric. (Honestly, I would trust PSNR more than SSIM with mean pooling.)
This is what I use for evaluation (using more than two scale levels doesn't make much of a difference):
import sys
from PIL import Image
import numpy as np
from scipy.ndimage import gaussian_filter
# Per-scale weights, finest scale first; the list length sets the number of
# scales msssim() evaluates. Only one weight is enabled, so a single scale with
# full SSIM is used. The remaining multi-scale weights, if wanted, are
# 0.2856, 0.3001, 0.2363, 0.1333.
WEIGHTS = [0.0448]
def msssim(file1, file2, weights=None):
    """Return a weighted multi-scale, MAD-pooled SSIM score for two image files.

    Both images are converted to linear-light luminance. Before each
    per-scale ``ssim`` call they are re-encoded with a 1/2.2 power. Every
    scale except the last uses only the contrast-structure term; the last
    uses full SSIM. Between scales both images are Gaussian-blurred and
    decimated by 2.

    Because each scale is pooled with ``mad`` (a deviation), identical images
    score exactly 0. This is not a similarity in [0, 1].

    Args:
        file1: Path (or file object) of the reference image.
        file2: Path (or file object) of the image to compare.
        weights: Optional per-scale weights, finest scale first. Defaults to
            the module-level ``WEIGHTS``; its length sets the number of scales.

    Returns:
        The weighted average of the per-scale pooled scores (numpy float).

    Raises:
        ValueError: if the two images have different dimensions, or if
            ``weights`` is empty.
    """
    if weights is None:
        weights = WEIGHTS
    n_scales = len(weights)
    if n_scales == 0:
        raise ValueError("weights must contain at least one entry")

    img1 = _load_luminance(file1)
    img2 = _load_luminance(file2)
    # Check shapes explicitly. Reshaping the second image with the first one's
    # dimensions would silently scramble it whenever the pixel counts match
    # but the dimensions differ (e.g. 100x200 vs 200x100).
    if img1.shape != img2.shape:
        raise ValueError(
            f"image sizes differ: {img1.shape[::-1]} vs {img2.shape[::-1]}"
        )

    scores = []
    for level in range(n_scales):
        is_last = level == n_scales - 1
        # cs_map is true for every scale but the last, so luminance only
        # contributes at the coarsest scale.
        scores.append(
            ssim(
                np.power(img1, 1.0 / 2.2),
                np.power(img2, 1.0 / 2.2),
                level,
                not is_last,
            )
        )
        if not is_last:
            # Blur, then decimate by 2 for the next (coarser) scale. This is
            # skipped after the last scale, whose result would be unused.
            img1 = gaussian_filter(img1, 1.08, truncate=1.5)[::2, ::2]
            img2 = gaussian_filter(img2, 1.08, truncate=1.5)[::2, ::2]
    return np.average(scores, weights=weights)


def _load_luminance(path):
    """Load an image as a 2-D float64 array of linear-light Rec. 709 luminance."""
    # The context manager closes the file handle. convert() returns a fully
    # loaded copy, so the pixel data stays valid after the file is closed.
    with Image.open(path) as im:
        rgb = np.asarray(im.convert("RGB"), dtype=np.float64) / 255.0
    # Undo the sRGB transfer curve to get linear light.
    linear = np.where(
        rgb > 0.04045, np.power((rgb + 0.055) / 1.055, 2.4), rgb / 12.92
    )
    # Rec. 709 / sRGB luminance coefficients.
    return (
        0.2126 * linear[:, :, 0]
        + 0.7152 * linear[:, :, 1]
        + 0.0722 * linear[:, :, 2]
    )
def mad(x, l):
    """Return the mean absolute deviation of ``x`` about a level-shrunk mean.

    The centre is ``mean(x) ** (0.5 ** l)``: the plain mean at ``l == 0``,
    its square root at ``l == 1``, its fourth root at ``l == 2``, and so on.

    Args:
        x: Array of per-pixel scores.
        l: Scale-level index (``msssim`` passes 0, 1, ...).

    Returns:
        ``mean(|x - centre|)`` as a numpy float.
    """
    # NOTE(review): for l > 0, a negative mean(x) is raised to a fractional
    # power, which yields NaN. Confirm whether that can occur in practice.
    exponent = 0.5 ** l
    centre = np.mean(x) ** exponent
    deviation = np.abs(x - centre)
    return np.mean(deviation)
def ssim(L1, L2, lvl, cs_map):
    """Return the MAD-pooled SSIM (or contrast-structure) map of two images.

    Local statistics use a Gaussian window with sigma 1.5 and truncate 3.
    In scipy that gives a radius of int(3 * 1.5 + 0.5) = 5, i.e. an 11x11
    window. The stabilising constants correspond to a dynamic range of 1.

    Args:
        L1, L2: Same-shaped float arrays (2-D luminance images as passed by
            ``msssim``).
        lvl: Scale-level index, forwarded to ``mad``.
        cs_map: If true, use only the contrast-structure term; otherwise use
            full SSIM (luminance times contrast-structure).

    Returns:
        ``mad`` of the selected per-pixel map at level ``lvl``.
    """
    k1_term = 0.01 ** 2  # (K1 * L)^2 with K1 = 0.01, L = 1
    k2_term = 0.03 ** 2  # (K2 * L)^2 with K2 = 0.03, L = 1
    window_sigma, window_truncate = 1.5, 3

    def local_mean(img):
        return gaussian_filter(img, window_sigma, truncate=window_truncate)

    mean_a, mean_b = local_mean(L1), local_mean(L2)
    mean_a_sq = mean_a * mean_a
    mean_b_sq = mean_b * mean_b
    mean_ab = mean_a * mean_b
    # Local (co)variances via E[xy] - E[x]E[y].
    var_a = local_mean(L1 * L1) - mean_a_sq
    var_b = local_mean(L2 * L2) - mean_b_sq
    cov_ab = local_mean(L1 * L2) - mean_ab

    cs_num = 2.0 * cov_ab + k2_term
    cs_den = var_a + var_b + k2_term
    if not cs_map:
        lum_num = 2.0 * mean_ab + k1_term
        lum_den = mean_a_sq + mean_b_sq + k1_term
        score_map = (lum_num * cs_num) / (lum_den * cs_den)
    else:
        score_map = cs_num / cs_den
    return mad(score_map, lvl)
def main():
    """Score each image given on the command line against a reference.

    Usage: ``PROG REFERENCE IMAGE [IMAGE ...]``. For every IMAGE, prints the
    ``msssim`` score and the image path, separated by a tab.

    Raises:
        SystemExit: with a usage message if fewer than two paths are given.
    """
    if len(sys.argv) < 3:
        # Without at least one comparison image the loop would do nothing
        # and exit silently, so tell the user how to call the script instead.
        prog = sys.argv[0] if sys.argv else "msssim"
        sys.exit(f"usage: {prog} REFERENCE IMAGE [IMAGE ...]")
    reference, *targets = sys.argv[1:]
    for target in targets:
        score = msssim(reference, target)
        print(f"{score!s}\t{target}")
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
Metadata
Metadata
Assignees
Labels
No labels