diff --git a/pyrato/edc.py b/pyrato/edc.py index 9bc300d4..6fd5310f 100644 --- a/pyrato/edc.py +++ b/pyrato/edc.py @@ -263,6 +263,10 @@ def energy_decay_curve_truncation( >>> ax.legend() """ + # flatten to allow signals with cdim > 1 + shape = data.time.shape + data = data.flatten() + energy_data = dsp.preprocess_rir( data, is_energy=is_energy, @@ -311,7 +315,7 @@ def energy_decay_curve_truncation( / np.nanmax(energy_decay_curve, axis=-1, keepdims=True) edc = pf.TimeData( - energy_decay_curve, data.times, comment=data.comment) + energy_decay_curve.reshape(shape), data.times, comment=data.comment) if plot: ax = pf.plot.time(data, dB=True, label='RIR') @@ -403,6 +407,9 @@ def energy_decay_curve_lundeby( >>> ax.legend() """ + # flatten to allow signals with cdim > 1 + shape = data.time.shape + data = data.flatten() energy_data = dsp.preprocess_rir( data, @@ -458,7 +465,7 @@ def energy_decay_curve_lundeby( / np.nanmax(energy_decay_curve, axis=-1, keepdims=True) edc = pf.TimeData( - energy_decay_curve, data.times, comment=data.comment) + energy_decay_curve.reshape(shape), data.times, comment=data.comment) if plot: ax = pf.plot.time(data, dB=True, label='RIR') @@ -551,6 +558,10 @@ def energy_decay_curve_chu( >>> ax.legend() """ + # flatten to allow signals with cdim > 1 + shape = data.cshape + data = data.flatten() + energy_data = dsp.preprocess_rir( data, is_energy=is_energy, @@ -588,6 +599,8 @@ def energy_decay_curve_chu( trunc_levels = 10*np.log10((psnr)) - threshold edc = truncate_energy_decay_curve(edc, trunc_levels) + edc = edc.reshape(shape) + if plot: plt.figure(figsize=(15, 3)) pf.plot.use('light') @@ -693,6 +706,9 @@ def energy_decay_curve_chu_lundeby( >>> ax.legend() """ + # flatten to allow signals with cdim > 1 + shape = data.time.shape + data = data.flatten() energy_data = dsp.preprocess_rir( data, @@ -754,7 +770,7 @@ def energy_decay_curve_chu_lundeby( / np.nanmax(energy_decay_curve, axis=-1, keepdims=True) edc = pf.TimeData( - energy_decay_curve, data.times, comment=data.comment) + energy_decay_curve.reshape(shape), data.times, comment=data.comment) if plot: ax = pf.plot.time(data, dB=True, label='RIR') diff --git a/tests/test_edc.py b/tests/test_edc.py index 17f2cf19..7770ae9d 100644 --- a/tests/test_edc.py +++ b/tests/test_edc.py @@ -96,3 +96,30 @@ def test_edc_sabine(): 6.37107964e-04, 5.46555336e-04, ]) npt.assert_almost_equal(edc, truth) + + +@pytest.mark.parametrize( + "edc_function", + [ra.edc.energy_decay_curve_chu, + ra.edc.energy_decay_curve_lundeby, + ra.edc.energy_decay_curve_chu_lundeby, + ra.energy_decay_curve_lundeby], +) +def test_multidim_edc(edc_function): + """ + Test if edcs from multichannel signal are equal to corresponding single + channel edcs. + """ + rir = pf.signals.files.room_impulse_response() + rir_oct = pf.dsp.filter.fractional_octave_bands(rir, 1) + shape = rir_oct.time.shape + edc = edc_function(rir_oct, channel_independent=True) + + assert shape == edc.time.shape + + edc = edc.flatten() + rir_oct = rir_oct.flatten() + + for i in range(edc.cshape[0]): + baseline = edc_function(rir_oct[i]) + npt.assert_array_equal(edc[i].time, baseline.time)