|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Tests for metrax audio metrics.""" |
| 16 | + |
| 17 | +import os |
| 18 | + |
| 19 | +os.environ['KERAS_BACKEND'] = 'jax' |
| 20 | + |
| 21 | +from absl.testing import absltest |
| 22 | +from absl.testing import parameterized |
| 23 | +import jax.numpy as jnp |
| 24 | +import metrax |
| 25 | +import numpy as np |
| 26 | +import torch |
| 27 | +from torchmetrics.functional.audio import snr as tm_snr |
| 28 | + |
# Seed NumPy's global RNG so every random fixture below is reproducible.
# The order of the random draws matters: 1D noise, 2D target, 2D noise.
np.random.seed(42)

# 1D fixture: five full periods of a sine wave sampled at 1000 points.
AUDIO_SHAPE_1D = (1000,)
AUDIO_TARGET_1D = np.sin(
    np.linspace(0, 2 * np.pi * 5, num=AUDIO_SHAPE_1D[0])
).astype(np.float32)
# Noisy predictions: the target plus Gaussian noise scaled by 0.1.
AUDIO_PREDS_1D_NOISY = (
    AUDIO_TARGET_1D + 0.1 * np.random.standard_normal(AUDIO_SHAPE_1D)
).astype(np.float32)
# Perfect predictions: the very same array object as the target.
AUDIO_PREDS_1D_PERFECT = AUDIO_TARGET_1D

# 2D fixture: a batch of 4 random signals with 500 samples each.
AUDIO_SHAPE_2D = (4, 500)
AUDIO_TARGET_2D = (
    5.0 * np.random.standard_normal(AUDIO_SHAPE_2D)
).astype(np.float32)
AUDIO_PREDS_2D_NOISY = (
    AUDIO_TARGET_2D + 0.5 * np.random.standard_normal(AUDIO_SHAPE_2D)
).astype(np.float32)

# Degenerate fixture: both target and predictions are all zeros.
AUDIO_SHAPE_ZEROS = (100,)
AUDIO_TARGET_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS, dtype=np.float32)
AUDIO_PREDS_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS, dtype=np.float32)
| 50 | + |
| 51 | + |
class AudioMetricsTest(parameterized.TestCase):
  """Compares metrax audio metrics with their torchmetrics counterparts."""

  @parameterized.named_parameters(
      dict(
          testcase_name='snr_1d_noisy_false_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_NOISY,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_1d_noisy_true_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_NOISY,
          zero_mean=True,
      ),
      dict(
          testcase_name='snr_1d_perfect_false_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_PERFECT,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_1d_perfect_true_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_PERFECT,
          zero_mean=True,
      ),
      dict(
          testcase_name='snr_2d_noisy_false_zero_mean',
          target_np=AUDIO_TARGET_2D,
          preds_np=AUDIO_PREDS_2D_NOISY,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_2d_noisy_true_zero_mean',
          target_np=AUDIO_TARGET_2D,
          preds_np=AUDIO_PREDS_2D_NOISY,
          zero_mean=True,
      ),
      dict(
          testcase_name='snr_zeros_false_zero_mean',
          target_np=AUDIO_TARGET_ZEROS,
          preds_np=AUDIO_PREDS_ZEROS,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_zeros_true_zero_mean',
          target_np=AUDIO_TARGET_ZEROS,
          preds_np=AUDIO_PREDS_ZEROS,
          zero_mean=True,
      ),
  )
  def test_snr(
      self,
      target_np: np.ndarray,
      preds_np: np.ndarray,
      zero_mean: bool,
  ):
    """Checks that metrax.SNR agrees with torchmetrics' signal_noise_ratio.

    Args:
      target_np: Ground-truth signal(s) as a NumPy array.
      preds_np: Predicted signal(s) as a NumPy array.
      zero_mean: Forwarded unchanged to both implementations.
    """
    # Value under test: build the metrax metric from JAX arrays and compute it.
    snr_metric = metrax.SNR.from_model_output(
        predictions=jnp.array(preds_np),
        targets=jnp.array(target_np),
        zero_mean=zero_mean,
    )
    actual_snr = snr_metric.compute()

    # Reference value: the torchmetrics output is reduced with .mean() and
    # converted to a Python scalar, so batched inputs compare as one number.
    reference_snr_tensor = tm_snr.signal_noise_ratio(
        preds=torch.from_numpy(preds_np),
        target=torch.from_numpy(target_np),
        zero_mean=zero_mean,
    )
    expected_snr = reference_snr_tensor.mean().item()

    mismatch_message = (
        f'SNR mismatch for zero_mean={zero_mean}.\n'
        f'Metrax SNR: {actual_snr:.8f} dB, '
        f'Torchmetrics SNR: {expected_snr:.8f} dB'
    )
    np.testing.assert_allclose(
        actual_snr,
        expected_snr,
        rtol=1e-5,
        atol=1e-5,
        err_msg=mismatch_message,
    )
| 134 | + |
| 135 | + |
# Hand control to absl's test runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
0 commit comments