Skip to content

Commit d63cd21

Browse files
committed
add audio_metrics module and snr to metrax
1 parent 3c7597a commit d63cd21

File tree

7 files changed

+281
-2
lines changed

7 files changed

+281
-2
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ pytest
88
rouge-score
99
scikit-learn
1010
tensorflow
11+
torchmetrics

src/metrax/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from metrax import audio_metrics
1516
from metrax import base
1617
from metrax import classification_metrics
1718
from metrax import image_metrics
@@ -42,6 +43,7 @@
4243
RecallAtK = ranking_metrics.RecallAtK
4344
RougeL = nlp_metrics.RougeL
4445
RougeN = nlp_metrics.RougeN
46+
SNR = audio_metrics.SNR
4547
SSIM = image_metrics.SSIM
4648
WER = nlp_metrics.WER
4749

@@ -70,6 +72,7 @@
7072
"RecallAtK",
7173
"RougeL",
7274
"RougeN",
75+
"SNR",
7376
"SSIM",
7477
"WER",
7578
]

src/metrax/audio_metrics.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A collection of different metrics for audio models."""
16+
17+
import flax
18+
import jax
19+
import jax.numpy as jnp
20+
from metrax import base
21+
22+
23+
@flax.struct.dataclass
class SNR(base.Average):
  r"""SNR (Signal-to-Noise Ratio) metric for audio.

  Computes the Signal-to-Noise Ratio (SNR) in decibels (dB) between a
  predicted audio signal and a ground truth audio signal, and averages it
  over a dataset.

  The SNR is defined as:

  .. math::

      SNR_{dB} = 10 \cdot \log_{10} \left( \frac{P_{signal}}{P_{noise}} \right)

  Where:
    - :math:`P_{signal}` is the power of the ground truth signal (`target`).
    - :math:`P_{noise}` is the power of the noise component, which is defined
      as the difference between the `target` and `preds` (`target - preds`).

  Both powers are accumulated as sums of squares over the last axis. Since the
  signal and the noise span the same number of samples, the ratio is the same
  as the ratio of the mean squared values. If `zero_mean=True`, the mean over
  the last axis is subtracted from `target` and `preds` before the powers are
  computed. A machine epsilon is added to both powers before dividing.
  """

  @staticmethod
  def _calculate_snr(
      preds: jax.Array,
      target: jax.Array,
      zero_mean: bool = False,
  ) -> jax.Array:
    """Computes SNR (Signal-to-Noise Ratio) values for a batch of audio signals.

    Args:
      preds: The estimated or predicted audio signal. The last axis is treated
        as the sample axis. Non-floating inputs (e.g. integer PCM samples) are
        converted to float32.
      target: The ground truth audio signal, with the same shape as `preds`.
        Non-floating inputs are converted to float32.
      zero_mean: If True, subtracts the mean over the last axis from `target`
        and `preds` before calculating the powers. Defaults to False.

    Returns:
      An array holding the SNR in decibels (dB) with the last axis of the
      inputs reduced, i.e. one value per signal (a scalar for 1D inputs).

    Raises:
      ValueError: If `preds` and `target` do not have the same shape.
    """
    if preds.shape != target.shape:
      raise ValueError(
          f'Input signals must have the same shape, but got {preds.shape} and'
          f' {target.shape}'
      )

    # Integer/bool signals cannot be used as-is: the two `lax.cond` branches
    # below would return different dtypes and `jnp.finfo` rejects non-inexact
    # dtypes. Promote them; floating inputs are left untouched.
    if not jnp.issubdtype(preds.dtype, jnp.inexact):
      preds = preds.astype(jnp.float32)
    if not jnp.issubdtype(target.dtype, jnp.inexact):
      target = target.astype(jnp.float32)

    # `lax.cond` (rather than a Python `if`) keeps this traceable when
    # `zero_mean` is passed through `jax.jit` as a traced value.
    target_processed, preds_processed = jax.lax.cond(
        zero_mean,
        lambda t, p: (
            t - jnp.mean(t, axis=-1, keepdims=True),
            p - jnp.mean(p, axis=-1, keepdims=True),
        ),
        lambda t, p: (t, p),
        target,
        preds,
    )
    noise = target_processed - preds_processed
    # eps keeps both powers strictly positive, so silent or perfectly
    # reconstructed signals give a finite result instead of log10(0) or x/0.
    eps = jnp.finfo(preds.dtype).eps
    signal_power = jnp.sum(target_processed**2, axis=-1) + eps
    noise_power = jnp.sum(noise**2, axis=-1) + eps

    snr = 10 * jnp.log10(base.divide_no_nan(signal_power, noise_power))
    return snr

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      targets: jax.Array,
      zero_mean: bool = False,
  ) -> 'SNR':
    """Computes SNR for a batch of audio signals and creates an SNR metric instance.

    Args:
      predictions: A JAX array of predicted audio signals.
      targets: A JAX array of ground truth audio signals, with the same shape
        as `predictions`.
      zero_mean: If True, subtracts the mean over the last axis from the
        signals before calculating their respective powers.

    Returns:
      An SNR instance containing the SNR value for the current batch,
      ready for averaging.

    Raises:
      ValueError: If `predictions` and `targets` do not have the same shape.
    """
    batch_snr_value = cls._calculate_snr(
        predictions,
        targets,
        zero_mean=zero_mean,
    )
    return super().from_model_output(values=batch_snr_value)

src/metrax/audio_metrics_test.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for metrax audio metrics."""
16+
17+
import os
18+
19+
os.environ['KERAS_BACKEND'] = 'jax'
20+
21+
from absl.testing import absltest
22+
from absl.testing import parameterized
23+
import jax.numpy as jnp
24+
import metrax
25+
import numpy as np
26+
import torch
27+
from torchmetrics.functional.audio import snr as tm_snr
28+
29+
np.random.seed(42)
30+
31+
# Simple 1D audio signal.
32+
AUDIO_SHAPE_1D = (1000,)
33+
AUDIO_TARGET_1D = np.sin(
34+
np.linspace(0, 2 * np.pi * 5, AUDIO_SHAPE_1D[0])
35+
).astype(np.float32)
36+
AUDIO_PREDS_1D_NOISY = (
37+
AUDIO_TARGET_1D + 0.1 * np.random.randn(*AUDIO_SHAPE_1D)
38+
).astype(np.float32)
39+
AUDIO_PREDS_1D_PERFECT = AUDIO_TARGET_1D
40+
# Multi-dimensional batch of signals
41+
AUDIO_SHAPE_2D = (4, 500)  # Batch of 4 signals with 500 samples each.
42+
AUDIO_TARGET_2D = (np.random.randn(*AUDIO_SHAPE_2D) * 5.0).astype(np.float32)
43+
AUDIO_PREDS_2D_NOISY = (
44+
AUDIO_TARGET_2D + 0.5 * np.random.randn(*AUDIO_SHAPE_2D)
45+
).astype(np.float32)
46+
# Target and preds are all zeros.
47+
AUDIO_SHAPE_ZEROS = (100,)
48+
AUDIO_TARGET_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS).astype(np.float32)
49+
AUDIO_PREDS_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS).astype(np.float32)
50+
51+
52+
class AudioMetricsTest(parameterized.TestCase):
  """Compares metrax audio metrics with their torchmetrics counterparts."""

  @parameterized.named_parameters(
      dict(
          testcase_name='snr_1d_noisy_false_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_NOISY,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_1d_noisy_true_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_NOISY,
          zero_mean=True,
      ),
      dict(
          testcase_name='snr_1d_perfect_false_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_PERFECT,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_1d_perfect_true_zero_mean',
          target_np=AUDIO_TARGET_1D,
          preds_np=AUDIO_PREDS_1D_PERFECT,
          zero_mean=True,
      ),
      dict(
          testcase_name='snr_2d_noisy_false_zero_mean',
          target_np=AUDIO_TARGET_2D,
          preds_np=AUDIO_PREDS_2D_NOISY,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_2d_noisy_true_zero_mean',
          target_np=AUDIO_TARGET_2D,
          preds_np=AUDIO_PREDS_2D_NOISY,
          zero_mean=True,
      ),
      dict(
          testcase_name='snr_zeros_false_zero_mean',
          target_np=AUDIO_TARGET_ZEROS,
          preds_np=AUDIO_PREDS_ZEROS,
          zero_mean=False,
      ),
      dict(
          testcase_name='snr_zeros_true_zero_mean',
          target_np=AUDIO_TARGET_ZEROS,
          preds_np=AUDIO_PREDS_ZEROS,
          zero_mean=True,
      ),
  )
  def test_snr(
      self,
      target_np: np.ndarray,
      preds_np: np.ndarray,
      zero_mean: bool,
  ):
    """Tests metrax.SNR against torchmetrics.functional.audio.snr."""
    # Reference value: per-signal SNR from torchmetrics, averaged to a scalar.
    reference_per_signal = tm_snr.signal_noise_ratio(
        preds=torch.from_numpy(preds_np),
        target=torch.from_numpy(target_np),
        zero_mean=zero_mean,
    )
    torchmetrics_snr_result = reference_per_signal.mean().item()

    # Value under test: a single-batch metrax metric reduced via compute().
    metric = metrax.SNR.from_model_output(
        predictions=jnp.array(preds_np),
        targets=jnp.array(target_np),
        zero_mean=zero_mean,
    )
    metrax_snr_result = metric.compute()

    mismatch_message = (
        f'SNR mismatch for zero_mean={zero_mean}.\n'
        f'Metrax SNR: {metrax_snr_result:.8f} dB, '
        f'Torchmetrics SNR: {torchmetrics_snr_result:.8f} dB'
    )
    np.testing.assert_allclose(
        metrax_snr_result,
        torchmetrics_snr_result,
        rtol=1e-5,
        atol=1e-5,
        err_msg=mismatch_message,
    )
134+
135+
136+
if __name__ == '__main__':
137+
absltest.main()

src/metrax/metrax_test.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@
5555
np.int32
5656
)
5757
IOU_TARGET_CLASS_IDS = np.array([0, 1])
58-
58+
# For audio_metrics.
59+
AUDIO_SHAPE = (2, 16000)
60+
AUDIO_PREDS = np.random.randn(*AUDIO_SHAPE).astype(np.float32)
61+
AUDIO_TARGETS = np.random.randn(*AUDIO_SHAPE).astype(np.float32)
5962

6063
class MetraxTest(parameterized.TestCase):
6164

@@ -168,7 +171,7 @@ class MetraxTest(parameterized.TestCase):
168171
'targets': TARGET_IMGS,
169172
'max_val': MAX_IMG_VAL,
170173
},
171-
),
174+
),
172175
(
173176
'rmse',
174177
metrax.RMSE,
@@ -202,6 +205,15 @@ class MetraxTest(parameterized.TestCase):
202205
'max_val': MAX_IMG_VAL,
203206
},
204207
),
208+
(
209+
'snr',
210+
metrax.SNR,
211+
{
212+
'predictions': AUDIO_PREDS,
213+
'targets': AUDIO_TARGETS,
214+
'zero_mean': False,
215+
},
216+
),
205217
)
206218
def test_metrics_jittable(self, metric, kwargs):
207219
"""Tests that jitted metrax metric yields the same result as non-jitted metric."""

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
RecallAtK = nnx_metrics.RecallAtK
3838
RougeL = nnx_metrics.RougeL
3939
RougeN = nnx_metrics.RougeN
40+
SNR = nnx_metrics.SNR
4041
SSIM = nnx_metrics.SSIM
4142
WER = nnx_metrics.WER
4243

@@ -64,6 +65,7 @@
6465
"RecallAtK",
6566
"RougeL",
6667
"RougeN",
68+
"SNR",
6769
"SSIM",
6870
"WER",
6971
]

src/metrax/nnx/nnx_metrics.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Accuracy(NnxWrapper):
3838
def __init__(self):
3939
super().__init__(metrax.Accuracy)
4040

41+
4142
class Average(NnxWrapper):
4243
"""An NNX class for the Metrax metric Average."""
4344

@@ -178,6 +179,13 @@ def __init__(self):
178179
super().__init__(metrax.RSQUARED)
179180

180181

182+
class SNR(NnxWrapper):
183+
"""An NNX class for the Metrax metric SNR.

The constructor takes no arguments and hands `metrax.SNR` to the
`NnxWrapper` initializer.
"""
184+
185+
def __init__(self):
186+
super().__init__(metrax.SNR)
187+
188+
181189
class SSIM(NnxWrapper):
182190
"""An NNX class for the Metrax metric SSIM."""
183191

0 commit comments

Comments
 (0)