Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/metrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
RougeN = nlp_metrics.RougeN
SNR = audio_metrics.SNR
SSIM = image_metrics.SSIM
TotalVariation = image_metrics.TotalVariation
WER = nlp_metrics.WER


Expand Down
66 changes: 66 additions & 0 deletions src/metrax/image_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,69 @@ def compute(self) -> jax.Array:
"""Returns the final Dice coefficient."""
epsilon = 1e-7
return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)

@flax.struct.dataclass
class TotalVariation(base.Average):
r"""Calculates and returns the Total Variation (TV) for one or more images.

The total variation is the sum of the absolute differences for neighboring
pixel-values in the input images. This measures how much noise is in the
images.

This implements the anisotropic 2-D version of the formula described here:

https://en.wikipedia.org/wiki/Total_variation_denoising
"""

@staticmethod
def _calculate_total_variation(
images: jax.Array,
) -> jax.Array:
"""Computes Total Variation values.

Args:
images: 4-D Array of shape ``(batch, height, width, channels)`` or
3-D Array of shape ``(height, width, channels)``.

Returns:
Total variation of 'images'.

If `images` was 4-D, return a 1-D float Array of shape `[batch]` with the
total variation for each image in the batch.
If `images` was 3-D, return a scalar float with the total variation for
that image.
"""
ndims = images.ndim
if ndims == 3: # (height, width, channels)
# Shift images by one pixel along the height and width.
pixel_dif1 = jnp.abs(images[1:, :, :] - images[:-1, :, :])
pixel_dif2 = jnp.abs(images[:, 1:, :] - images[:, :-1, :])
sum_axis = None
elif ndims == 4: # (batch, height, width, channels)
# Shift images by one pixel along the height and width.
pixel_dif1 = jnp.abs(images[:, 1:, :, :] - images[:, :-1, :, :])
pixel_dif2 = jnp.abs(images[:, :, 1:, :] - images[:, :, :-1, :])
sum_axis = [1, 2, 3]
else:
raise ValueError(
f'Input images must be either 3 or 4-dimensional, got {ndims} dimensions instead.'
)

return jnp.sum(pixel_dif1, axis=sum_axis) + jnp.sum(pixel_dif2, axis=sum_axis)


@classmethod
def from_model_output(
cls,
predictions: jax.Array
) -> 'TotalVariation':
"""Computes the Total Variation for a batch of images and creates a TotalVariation metric instance.

Args:
predictions: A JAX array of predicted images, with shape ``(batch, H, W, C)``.

Returns:
A ``TotalVariation`` instance containing per‑image total variation values.
"""
total_variation = cls._calculate_total_variation(predictions)
return super().from_model_output(values=total_variation)
77 changes: 77 additions & 0 deletions src/metrax/image_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,27 @@
DICE_ALL_ZEROS = (np.array([0, 0, 0, 0]), np.array([0, 0, 0, 0]))
DICE_NO_OVERLAP = (np.array([1, 1, 0, 0]), np.array([0, 0, 1, 1]))

# Test data for TotalVariation
# Case 1: Basic, float normalized [0,1], single channel (3D)
TV_IMG_SHAPE_1 = (16, 16, 1) # height, width, channels
TV_IMG_1 = np.random.rand(*TV_IMG_SHAPE_1).astype(np.float32)

# Case 2: Multi-channel (3), float normalized [0,1] (3D)
TV_IMG_SHAPE_2 = (32, 32, 3)
TV_IMG_2 = np.random.rand(*TV_IMG_SHAPE_2).astype(np.float32)

# Case 3: Batch of single channel images (4D)
TV_IMG_SHAPE_3 = (4, 16, 16, 1) # batch, height, width, channels
TV_IMG_3 = np.random.rand(*TV_IMG_SHAPE_3).astype(np.float32)

# Case 4: Batch of multi-channel images (4D)
TV_IMG_SHAPE_4 = (4, 32, 32, 3) # batch, height, width, channels
TV_IMG_4 = np.random.rand(*TV_IMG_SHAPE_4).astype(np.float32)

# Case 5: Constant image (should have zero variation)
TV_IMG_SHAPE_5 = (16, 16, 1)
TV_IMG_5 = np.ones(TV_IMG_SHAPE_5, dtype=np.float32)


class ImageMetricsTest(parameterized.TestCase):

Expand Down Expand Up @@ -553,6 +574,62 @@ def test_dice(self, y_true, y_pred):

np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)

@parameterized.named_parameters(
(
'tv_single_channel_3d',
TV_IMG_1,
),
(
'tv_multichannel_3d',
TV_IMG_2,
),
(
'tv_batch_single_channel_4d',
TV_IMG_3,
),
(
'tv_batch_multichannel_4d',
TV_IMG_4,
),
(
'tv_constant_image',
TV_IMG_5,
),
)
def test_total_variation_against_tensorflow(
self,
images_np: np.ndarray,
) -> None:
"""Test that TotalVariation metric computes values close to tf.image.total_variation."""

# Calculate TV using Metrax
# convert to uniform [B, H, W, C] otherwise `for image in images_np` will be 2D
# if input is 3D
images_np = images_np if images_np.ndim == 4 else np.expand_dims(images_np, axis=0)
metric = None
for image in images_np:
update = metrax.TotalVariation.from_model_output(
predictions=jnp.array(image)
)
metric = update if metric is None else metric.merge(update)
metrax_tv = metric.compute()

# Calculate TV using TensorFlow
tf_tv = tf.image.total_variation(tf.convert_to_tensor(images_np))
tf_mean = tf.reduce_mean(tf_tv).numpy()

# For constant image, TV should be 0
if np.array_equal(images_np, TV_IMG_5):
np.testing.assert_allclose(metrax_tv, 0.0, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(tf_mean, 0.0, rtol=1e-5, atol=1e-5)
else:
np.testing.assert_allclose(
metrax_tv,
tf_mean,
rtol=1e-5,
atol=1e-5,
err_msg='Total Variation mismatch',
)

if __name__ == '__main__':
absltest.main()
9 changes: 8 additions & 1 deletion src/metrax/metrax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'the quick brown fox jumps over the lazy dog',
'hello beautiful world',
]
# For image_metrics.SSIM and image_metrics.PSNR.
# For image_metrics.SSIM, image_metrics.PSNR and image_metrics.TotalVariation.
IMG_SHAPE = (4, 32, 32, 3)
PRED_IMGS = np.random.rand(*IMG_SHAPE).astype(np.float32)
TARGET_IMGS = np.random.rand(*IMG_SHAPE).astype(np.float32)
Expand Down Expand Up @@ -214,6 +214,13 @@ class MetraxTest(parameterized.TestCase):
'zero_mean': False,
},
),
(
'total_variation',
metrax.TotalVariation,
{
'predictions': PRED_IMGS
}
)
)
def test_metrics_jittable(self, metric, kwargs):
"""Tests that jitted metrax metric yields the same result as non-jitted metric."""
Expand Down
1 change: 1 addition & 0 deletions src/metrax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
RougeN = nnx_metrics.RougeN
SNR = nnx_metrics.SNR
SSIM = nnx_metrics.SSIM
TotalVariation = nnx_metrics.TotalVariation
WER = nnx_metrics.WER


Expand Down
7 changes: 7 additions & 0 deletions src/metrax/nnx/nnx_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ def __init__(self):
super().__init__(metrax.SSIM)


class TotalVariation(NnxWrapper):
"""An NNX class for the Metrax metric TotalVariation."""

def __init__(self):
super().__init__(metrax.TotalVariation)


class WER(NnxWrapper):
"""An NNX class for the Metrax metric WER."""

Expand Down