Skip to content

Commit 664732d

Browse files
authored
add SSIM to metrax (#84)
* add SSIM metric to metrax * modify image_metrics_test class name * move private functions into SSIM
1 parent 0ca5403 commit 664732d

File tree

7 files changed

+569
-1
lines changed

7 files changed

+569
-1
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ keras-nlp
66
keras-rs
77
pytest
88
rouge-score
9-
scikit-learn
9+
scikit-learn
10+
tensorflow

src/metrax/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from metrax import base
1616
from metrax import classification_metrics
17+
from metrax import image_metrics
1718
from metrax import nlp_metrics
1819
from metrax import ranking_metrics
1920
from metrax import regression_metrics
@@ -38,6 +39,7 @@
3839
RecallAtK = ranking_metrics.RecallAtK
3940
RougeL = nlp_metrics.RougeL
4041
RougeN = nlp_metrics.RougeN
42+
SSIM = image_metrics.SSIM
4143
WER = nlp_metrics.WER
4244

4345

@@ -62,5 +64,6 @@
6264
"RecallAtK",
6365
"RougeL",
6466
"RougeN",
67+
"SSIM",
6568
"WER",
6669
]

src/metrax/image_metrics.py

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A collection of different metrics for image models."""
16+
17+
import flax
18+
import jax
19+
from jax import lax
20+
import jax.numpy as jnp
21+
from metrax import base
22+
23+
24+
def _gaussian_kernel1d(sigma, radius):
  r"""Builds a normalized 1D Gaussian smoothing kernel.

  The kernel samples the unnormalized Gaussian

  .. math::
    \phi(x) = e^{-\frac{x^2}{2\sigma^2}}

  at the integer offsets :math:`-radius, \dots, +radius` and then divides by
  the sum of the samples so that the returned weights add up to 1.

  Args:
    sigma (float): Standard deviation (:math:`\sigma`) of the Gaussian; larger
      values give a wider, flatter kernel.
    radius (int): Half-width of the kernel. The result has
      :math:`2 \times radius + 1` taps, centered on offset zero.

  Returns:
    jnp.ndarray: A 1D JAX array of length ``2 * radius + 1`` holding the
    normalized kernel weights.
  """
  offsets = jnp.arange(-radius, radius + 1)
  # Exponent coefficient -1 / (2 * sigma^2), applied to the squared offsets.
  exponent_scale = -0.5 / (sigma * sigma)
  weights = jnp.exp(exponent_scale * offsets**2)
  return weights / weights.sum()
54+
55+
56+
@flax.struct.dataclass
class SSIM(base.Average):
  r"""Structural Similarity Index Measure (SSIM) metric for image batches.

  For every (prediction, target) image pair this metric computes an SSIM score
  and hands the per-image scores to ``base.Average`` for aggregation.

  SSIM is the product of a luminance term (l), a contrast term (c) and a
  structure term (s):

  .. math::
    SSIM(x, y) = [l(x, y)]^\alpha \cdot [c(x, y)]^\beta \cdot [s(x,
    y)]^\gamma

  with

  - :math:`l(x, y) = \frac{2\mu_x\mu_y + c_1}{\mu_x^2 + \mu_y^2 + c_1}`
  - :math:`c(x, y) = \frac{2\sigma_x\sigma_y + c_2}{\sigma_x^2 +
    \sigma_y^2 + c_2}`
  - :math:`s(x, y) = \frac{\sigma_{xy} + c_3}{\sigma_x\sigma_y + c_3}`

  This implementation uses the usual simplification :math:`\alpha = \beta =
  \gamma = 1` and :math:`c_3 = c_2 / 2`, which collapses the product into

  .. math::
    SSIM(x, y) = \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 +
    \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}

  Here :math:`\mu_x, \mu_y` are Gaussian-weighted local means,
  :math:`\sigma_x^2, \sigma_y^2` the local variances, :math:`\sigma_{xy}` the
  local covariance, and :math:`c_1 = (K_1 L)^2`, :math:`c_2 = (K_2 L)^2` are
  stabilization constants in which :math:`L` is the dynamic range of the
  pixel values and :math:`K_1, K_2` are small constants (defaults 0.01 and
  0.03).
  """

  @staticmethod
  def _calculate_ssim(
      img1: jnp.ndarray,
      img2: jnp.ndarray,
      max_val: float,
      filter_size: int = 11,
      filter_sigma: float = 1.5,
      k1: float = 0.01,
      k2: float = 0.03,
  ) -> jnp.ndarray:
    """Returns one SSIM score per image pair in two image batches.

    Each channel is scored on its own with a Gaussian-weighted sliding window
    (no padding, stride 1); the SSIM map of a channel is averaged over all
    window positions, and the per-channel scores are then averaged.

    Args:
      img1: First batch of images, shape ``(batch, height, width, channels)``.
      img2: Second batch of images, same shape as ``img1``.
      max_val: Dynamic range of the pixel values (e.g. 1.0 for images scaled
        to [0, 1], 255 for uint8 images).
      filter_size: Side length of the Gaussian window. Expected to be odd; the
        kernel is built with radius ``(filter_size - 1) // 2``.
      filter_sigma: Standard deviation of the Gaussian window.
      k1: Small constant stabilizing the luminance term.
      k2: Small constant stabilizing the contrast/structure term.

    Returns:
      A 1D JAX array of shape ``(batch,)`` with the SSIM of each image pair.

    Raises:
      ValueError: If the two inputs differ in shape, are not 4D, or are
        smaller than ``filter_size`` in height or width.
    """
    if img1.shape != img2.shape:
      raise ValueError(
          f'Input images must have the same shape, but got {img1.shape} and'
          f' {img2.shape}'
      )
    if img1.ndim != 4:  # (batch, H, W, C)
      raise ValueError(
          'Input images must be 4D tensors (batch, height, width, channels),'
          f' but got {img1.ndim}D'
      )
    height, width = img1.shape[-3], img1.shape[-2]
    if height < filter_size or width < filter_size:
      raise ValueError(
          f'Image dimensions ({height}x{width}) must be at'
          f' least filter_size x filter_size ({filter_size}x{filter_size}).'
      )

    first = img1.astype(jnp.float32)
    second = img2.astype(jnp.float32)

    # Separable Gaussian -> 2D window, laid out as (H_k, W_k, C_in=1, C_out=1).
    window_1d = _gaussian_kernel1d(filter_sigma, (filter_size - 1) // 2)
    window = jnp.outer(window_1d, window_1d)[:, :, jnp.newaxis, jnp.newaxis]

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    def _windowed_mean(plane):
      """Gaussian-weighted local mean of a ``(batch, H, W, 1)`` array."""
      dims = lax.conv_dimension_numbers(
          plane.shape, window.shape, ('NHWC', 'HWIO', 'NHWC')
      )
      return lax.conv_general_dilated(
          plane,
          window,
          window_strides=(1, 1),
          padding='VALID',
          dimension_numbers=dims,
      )

    def _channel_ssim(x_ch, y_ch):
      """Mean SSIM per image for one channel; inputs are (batch, H, W, 1)."""
      mu_x = _windowed_mean(x_ch)
      mu_y = _windowed_mean(y_ch)

      mu_xx = mu_x**2
      mu_yy = mu_y**2
      mu_xy = mu_x * mu_y

      # Local (co)variances via E[ab] - E[a]E[b].
      var_x = _windowed_mean(x_ch**2) - mu_xx
      var_y = _windowed_mean(y_ch**2) - mu_yy
      cov_xy = _windowed_mean(x_ch * y_ch) - mu_xy

      luminance_num = 2 * mu_xy + c1
      structure_num = 2 * cov_xy + c2
      luminance_den = mu_xx + mu_yy + c1
      structure_den = var_x + var_y + c2

      ssim_map = (luminance_num * structure_num) / (
          luminance_den * structure_den
      )
      # Average over window positions (and the singleton channel axis).
      return jnp.mean(ssim_map, axis=(1, 2, 3))

    per_channel_scores = [
        _channel_ssim(
            lax.dynamic_slice_in_dim(first, ch, 1, axis=3),  # (batch, H, W, 1)
            lax.dynamic_slice_in_dim(second, ch, 1, axis=3),  # (batch, H, W, 1)
        )
        for ch in range(first.shape[-1])
    ]

    # (batch, num_channels) -> (batch,)
    return jnp.mean(jnp.stack(per_channel_scores, axis=-1), axis=-1)

  @classmethod
  def from_model_output(  # type: ignore[override]
      cls,
      predictions: jax.Array,  # Predicted images (y_pred)
      targets: jax.Array,  # Ground truth images (y_true)
      max_val: float,  # Dynamic range of pixel values
      filter_size: int = 11,
      filter_sigma: float = 1.5,
      k1: float = 0.01,
      k2: float = 0.03,
  ) -> 'SSIM':
    """Builds an SSIM metric instance from one batch of images.

    The per-image SSIM values of the batch are computed with
    ``_calculate_ssim`` and forwarded as ``values`` to the parent class's
    ``from_model_output``.

    Args:
      predictions: Predicted images, shape ``(batch, height, width,
        channels)``.
      targets: Ground truth images, shape ``(batch, height, width,
        channels)``.
      max_val: Maximum possible pixel value (dynamic range) of the images,
        e.g. 1.0 for float images in [0, 1] or 255 for uint8 images.
      filter_size: Side length of the Gaussian window (default 11).
      filter_sigma: Standard deviation of the Gaussian window (default 1.5).
      k1: Stability constant of the luminance term (default 0.01).
      k2: Stability constant of the contrast/structure term (default 0.03).

    Returns:
      An SSIM instance holding the statistics of this batch.
    """
    # Shape (batch_size,).
    per_image_ssim = cls._calculate_ssim(
        predictions,
        targets,
        max_val=max_val,
        filter_size=filter_size,
        filter_sigma=filter_sigma,
        k1=k1,
        k2=k2,
    )
    return super().from_model_output(values=per_image_ssim)

0 commit comments

Comments
 (0)