Skip to content

Commit c53cfba

Browse files
committed
account for review comments
1 parent bd941f7 commit c53cfba

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

pyrato/parametric.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def energy_decay_curve(
1212
reverberation_time : float | np.ndarray,
1313
energy : float | np.ndarray = 1,
1414
) -> pf.TimeData:
15-
r"""Calculate the energy decay curve from the reverberation time and energy.
15+
r"""Calculate the energy decay curve for the reverberation time and energy.
1616
1717
The energy decay curve is calculated as
1818
@@ -27,9 +27,13 @@ def energy_decay_curve(
2727
times : numpy.ndarray[float]
2828
The times at which the energy decay curve is evaluated.
2929
reverberation_time : float | numpy.ndarray[float]
30-
The reverberation time in seconds.
30+
The reverberation time in seconds. The an array is passed, a energy
31+
decay curve is calculated for each reverberation time.
3132
energy : float | numpy.ndarray[float], optional
32-
The initial energy of the sound field, by default 1.
33+
The initial energy of the sound field, by default 1. If
34+
``reverberation_time`` is an array, the shape of ``energy`` is required
35+
to match the shape or be broadcastable to the shape of
36+
``reverberation_time``.
3337
3438
Returns
3539
-------
@@ -48,7 +52,7 @@ def energy_decay_curve(
4852
>>> import pyfar as pf
4953
>>>
5054
>>> times = np.linspace(0, 3, 50)
51-
>>> T_60 = 2
55+
>>> T_60 = [2, 1]
5256
>>> edc = pyrato.parametric.energy_decay_curve(times, T_60)
5357
>>> pf.plot.time(edc, log_prefix=10, dB=True)
5458
@@ -69,12 +73,12 @@ def energy_decay_curve(
6973
raise ValueError("Energy must be greater than or equal to zero.")
7074

7175
if reverberation_time.shape != energy.shape:
72-
shape = np.broadcast_shapes(energy.shape, reverberation_time.shape)
7376
try:
74-
energy = np.broadcast_to(energy, shape)
77+
energy = np.broadcast_to(energy, reverberation_time.shape)
7578
except ValueError as error:
7679
raise ValueError(
77-
"Reverberation time and energy must have the same shape.",
80+
"Reverberation time and energy must be broadcastable to the "
81+
"same shape.",
7882
) from error
7983

8084
matching_shape = reverberation_time.shape

tests/test_edc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from numpy import array
88

99
import numpy.testing as npt
10+
import pytest
1011

1112
import pyrato as ra
1213

@@ -70,3 +71,12 @@ def test_parametric_edc():
7071

7172
assert edc.cshape == (3, 2)
7273

74+
75+
def test_parametric_edc_wrong_shapes():
76+
"""Test error handling for wrong shapes."""
77+
times = np.linspace(0, 0.25, 50)
78+
T_60 = np.array([2, 1])
79+
energy = np.array([1, 1, 1])
80+
81+
with pytest.raises(ValueError, match="same shape."):
82+
ra.parametric.energy_decay_curve(times, T_60, energy=energy)

0 commit comments

Comments
 (0)