diff --git a/spharpy/__init__.py b/spharpy/__init__.py index e5360d33..0b91f37d 100644 --- a/spharpy/__init__.py +++ b/spharpy/__init__.py @@ -24,6 +24,7 @@ from . import interpolate from . import spatial from . import special +from . import sht __all__ = [ @@ -41,4 +42,5 @@ 'spatial', 'special', 'SamplingSphere', + 'sht' ] diff --git a/spharpy/sht.py b/spharpy/sht.py new file mode 100644 index 00000000..41466a20 --- /dev/null +++ b/spharpy/sht.py @@ -0,0 +1,174 @@ +import numpy as np +from pyfar import Signal, TimeData, FrequencyData +from pyfar import matrix_multiplication +from . import SphericalHarmonics +from . import SphericalHarmonicDefinition +from . import SphericalHarmonicSignal +from . import SphericalHarmonicTimeData +from . import SphericalHarmonicFrequencyData + + +def sht(signal, spherical_harmonics, axis='auto'): + """Compute the spherical harmonic transform + + Parameters + ---------- + signal : Signal, TimeData, or FrequencyData + the signal for which the spherical harmonic transform is computed + spherical_harmonics : :class:`spharpy.SphericalHarmonics` + Spherical harmonics object + axis : integer or 'auto' + Axis along which the spherical harmonic transform is computed. If 'auto' the + transformation is computed along the axis which matches the number + of spherical samples of the spherical_harmonics basis + + Returns + ---------- + sh_signal : SphericalHarmonicSignal, SphericalHarmonicsTimeData, + or SphericalHarmonicsFrequencyData + signal with spherical harmonic coefficients. According to + SphericalHarmonicsAudio definitions, the spherical harmonic + coefficients are always in the second to last dimension. The + order of all other channels remains unchanged. + References + ---------- + + [#] Rafaely, B. (2015). Fundamentals of Spherical Array Processing, + (J. Benesty and W. Kellermann, Eds.) Springer Berlin Heidelberg, + 2nd ed., 196 pages. doi:10.1007/978-3-319-99561-8 + [#] Ramani Duraiswami, Dimitry N. Zotkin, and Nail A. Gumerov: "Inter- + polation and range extrapolation of HRTFs." IEEE Int. Conf. + Acoustics, Speech, and Signal Processing (ICASSP), Montreal, + Canada, May 2004, p. 45-48, doi: 10.1109/ICASSP.2004.1326759. + """ + if isinstance(signal, (Signal, TimeData)): + data = signal.time + elif isinstance(signal, FrequencyData): + data = signal.freq + else: + raise ValueError("Input signal must be a Signal, TimeData, or " + f"FrequencyData but is {type(signal)}") + + if not isinstance(spherical_harmonics, SphericalHarmonics): + raise ValueError("spherical_harmonics must be SphericalHarmonics " + f"but is {type(spherical_harmonics)}") + + Y_inv = spherical_harmonics.basis_inv + if axis == 'auto': + axis = np.where(np.array(signal.cshape) == Y_inv.shape[1])[0] + if len(axis) == 0: + raise ValueError("No axes matches the number of spherical " + "harmonics basis functions") + if len(axis) > 1: + raise ValueError("Too many axis match the number of spherical " + "harmonics basis functions") + axis = axis[0] + + if signal.cshape[axis] != Y_inv.shape[1]: + raise ValueError("Spherical samples of provided axis does not match " + "the number of spherical harmonics basis functions.") + + # move spherical samples to -2 + data = np.moveaxis(data, axis, -2) + + # perform transform + data_nm = matrix_multiplication((Y_inv, data)) + + # ensure result has 3 dimensions + if len(data_nm.shape) < 3: + data_nm = data_nm[np.newaxis, ...] + + # set up SH definition + shd = SphericalHarmonicDefinition( + n_max=int(np.sqrt(data_nm.shape[-2])-1), + basis_type=spherical_harmonics.basis_type, + normalization=spherical_harmonics.normalization, + channel_convention=spherical_harmonics.channel_convention, + condon_shortley=spherical_harmonics.condon_shortley) + + if isinstance(signal, Signal): + sh_signal = SphericalHarmonicSignal.from_definition( + sh_definition=shd, + data=data_nm, + sampling_rate=signal.sampling_rate, + fft_norm=signal.fft_norm, + is_complex=signal.complex, + comment=signal.comment) + elif isinstance(signal, TimeData): + sh_signal = SphericalHarmonicTimeData.from_definition( + sh_definition=shd, + data=data_nm, + times=signal.times, + comment=signal.comment, + is_complex=False) + elif isinstance(signal, FrequencyData): + sh_signal = SphericalHarmonicFrequencyData.from_definition( + sh_definition=shd, + data=data_nm, + frequencies=signal.frequencies, + comment=signal.comment) + + return sh_signal + + +def isht(sh_signal, coordinates): + """Compute the inverse spherical harmonic transform + + Parameters + ---------- + sh_signal: SphericalHarmonicsSignal, SphericalHarmonicsTimeData, or + SphericalHarmonicsFrequencyData + The spherical harmonic signal for which the inverse spherical + harmonic transform is computed. + coordinates: :class:`spharpy.samplings.Coordinates`, :doc:`pf.Coordinates + ` + Coordinates for which the inverse SH transform is computed + + Returns + ---------- + signal : Signal, TimeData, or FrequencyData + inverse transformed signal in space domain. The spherical + samples are always in the second to last dimension. All other + channels remain unchaged. + + """ + if isinstance(sh_signal, (SphericalHarmonicSignal, + SphericalHarmonicTimeData)): + data = sh_signal.time + elif isinstance(sh_signal, SphericalHarmonicFrequencyData): + data = sh_signal.freq + else: + raise ValueError("Input signal has to be SphericalHarmonicSignal, " + "SphericalHarmonicTimeData, or " + "SphericalHarmonicFrequencyData " + f"but is {type(sh_signal)}") + + # get spherical harmonics basis functions according to sh_signals + # properties + spherical_harmonics = SphericalHarmonics( + sh_signal.n_max, + coordinates=coordinates, + basis_type=sh_signal.basis_type, + channel_convention=sh_signal.channel_convention, + normalization=sh_signal.normalization, + inverse_method="pseudo_inverse", + condon_shortley=sh_signal.condon_shortley) + + data = matrix_multiplication((spherical_harmonics.basis, data)) + + if isinstance(sh_signal, SphericalHarmonicSignal): + signal = Signal(data, + sh_signal.sampling_rate, + fft_norm=sh_signal.fft_norm, + comment=sh_signal.comment, + is_complex=sh_signal.complex) + elif isinstance(sh_signal, SphericalHarmonicTimeData): + signal = TimeData(data=data, + times=sh_signal.times, + comment=sh_signal.comment, + is_complex=sh_signal.complex) + else: + signal = FrequencyData(data=data, + frequencies=sh_signal.frequencies, + comment=sh_signal.comment) + return signal diff --git a/tests/test_sht.py b/tests/test_sht.py new file mode 100644 index 00000000..c3d1a887 --- /dev/null +++ b/tests/test_sht.py @@ -0,0 +1,218 @@ +import numpy as np +import numpy.testing as npt +import pyfar as pf +from pytest import raises, warns, mark +from spharpy.sht import sht, isht +from spharpy.classes.sh import SphericalHarmonicDefinition, SphericalHarmonics +from spharpy import SphericalHarmonicSignal +from spharpy import SphericalHarmonicTimeData +from spharpy import SphericalHarmonicFrequencyData +from spharpy import samplings + + +def test_sht_input_parameter(): + input_signal = np.zeros((3, 12, 2)) + n_max = 2 + sampling = samplings.equiangular(n_max=n_max) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + with raises(ValueError, match="Input signal must be a Signal, TimeData, " + f"or FrequencyData but is {type(input_signal)}"): + _ = sht(signal=input_signal, spherical_harmonics=sh) + + # test if SH in SphericalHarmonics object + + +def test_sht_output_parameter(): + n_max = 1 + sampling = samplings.equiangular(n_max=n_max) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + # test Signal + signal = pf.Signal(data=np.zeros((1, 16, 4)), sampling_rate=48000) + test = sht(signal=signal, spherical_harmonics=sh) + assert isinstance(test, SphericalHarmonicSignal) + + # test TimeData + signal = pf.TimeData(data=np.zeros((1, 16, 4)), + times=[1, 2, 3, 4]) + test = sht(signal=signal, spherical_harmonics=sh) + assert isinstance(test, SphericalHarmonicTimeData) + + # test FrequencyData + signal = pf.FrequencyData(data=np.zeros((1, 16, 4)), + frequencies=[1, 2, 3, 4]) + test = sht(signal=signal, spherical_harmonics=sh) + assert isinstance(test, SphericalHarmonicFrequencyData) + + +def test_sht_assert_num_channels(): + "test assert match of number of channels and number of sampling positions" + n_max = 3 + signal = pf.Signal(data=np.zeros((7, 512)), sampling_rate=48000) + sampling = samplings.equiangular(n_max=n_max) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + with raises(ValueError, match="Spherical samples of provided axis does " + "not match the number of spherical " + "harmonics basis functions."): + _ = sht(signal, sh, axis=0) + + +def test_isht_input_parameter(): + n_max = 1 + data = np.zeros((1, (n_max+1) ** 2, 16)) + sampling = samplings.gaussian(n_max=n_max) + with raises(ValueError, + match="Input signal has to be SphericalHarmonicSignal, " + "SphericalHarmonicTimeData, or " + "SphericalHarmonicFrequencyData " + f"but is {type(data)}"): + _ = isht(sh_signal=data, coordinates=sampling) + + +def test_isht_output_parameter(): + n_max = 1 + data = np.zeros((1, (n_max+1) ** 2, 5)) + sampling = samplings.gaussian(n_max=n_max) + + # test Signal + a_nm = SphericalHarmonicSignal(data, + basis_type='real', + channel_convention='ACN', + condon_shortley=True, + normalization='N3D', + sampling_rate=48000) + test = isht(sh_signal=a_nm, coordinates=sampling) + assert isinstance(test, pf.Signal) + + # test TimeData + a_nm = SphericalHarmonicTimeData(data, + times=[1, 2, 3, 4, 5], + basis_type='real', + channel_convention='ACN', + condon_shortley=True, + normalization='N3D') + test = isht(sh_signal=a_nm, coordinates=sampling) + assert isinstance(test, pf.TimeData) + + # test FrequencyData + a_nm = SphericalHarmonicFrequencyData( + data, + frequencies=[1, 2, 3, 4, 5], + basis_type='real', + channel_convention='ACN', + condon_shortley=True, + normalization='N3D') + + test = isht(sh_signal=a_nm, coordinates=sampling) + assert isinstance(test, pf.FrequencyData) + + +def test_sht_auto_axis(): + "test warning wrong axis" + n_max = 3 + signal = pf.Signal(data=np.zeros((7, 1, 32)), sampling_rate=48000) + sampling = samplings.equiangular(n_max=n_max) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + with raises(ValueError, match="No axes matches the number of spherical " + "harmonics basis functions"): + _ = sht(signal, sh, axis='auto') + + signal = pf.Signal(data=np.zeros((64, 64, 64)), sampling_rate=48000) + with raises(ValueError, match="Too many axis match the number of " + "spherical harmonics basis functions"): + _ = sht(signal, sh, axis='auto') + + +def test_in_out_dimensions(): + n_max = 3 + n_samples = 128 + sampling = samplings.equiangular(n_max=n_max) + signal = pf.Signal(data=np.zeros((sampling.csize, n_samples)), + sampling_rate=48000) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + sh_signal = sht(signal, sh, axis=0) + assert sh_signal.n_samples == n_samples + assert sh_signal.cshape[-1] == int(np.power(n_max+1, 2)) + assert sh_signal.cshape[0] == 1 + + signal = pf.Signal(data=np.zeros((1, sampling.csize, n_samples)), + sampling_rate=48000) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + sh_signal = sht(signal, sh, axis=1) + assert sh_signal.n_samples == n_samples + assert sh_signal.cshape[-1] == int(np.power(n_max+1, 2)) + assert sh_signal.cshape[0] == 1 + + signal = pf.Signal(data=np.zeros((sampling.csize, 1, n_samples)), + sampling_rate=48000) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + sh_signal = sht(signal, sh, axis=0) + assert sh_signal.n_samples == n_samples + assert sh_signal.cshape[-1] == int(np.power(n_max+1, 2)) + assert sh_signal.cshape[0] == 1 + + signal = pf.Signal(data=np.zeros((sampling.csize, 1, 2, 3, n_samples)), + sampling_rate=48000) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + sh_signal = sht(signal, sh, axis=0) + assert sh_signal.n_samples == n_samples + assert sh_signal.cshape[-1] == int(np.power(n_max+1, 2)) + assert sh_signal.cshape[0] == 1 + assert sh_signal.cshape[1] == 2 + assert sh_signal.cshape[2] == 3 + + signal = pf.Signal(data=np.zeros((1, 2, sampling.csize, 3, n_samples)), + sampling_rate=48000) + sh = SphericalHarmonics(n_max=n_max, coordinates=sampling) + + sh_signal = sht(signal, sh, axis=2) + assert sh_signal.n_samples == n_samples + assert sh_signal.cshape[-1] == int(np.power(n_max+1, 2)) + assert sh_signal.cshape[0] == 1 + assert sh_signal.cshape[1] == 2 + assert sh_signal.cshape[2] == 3 + + +@mark.parametrize("n_max", [1, 3, 12, 20]) +@mark.parametrize("basis_type", ["real", "complex"]) +@mark.parametrize("normalization", ["N3D", "SN3D"]) +@mark.parametrize("condon_shortley", [True, False]) +def test_back_and_forth(n_max, basis_type, normalization, condon_shortley): + + sampling = samplings.gaussian(n_max=n_max) + # create unit amplitude SH coefficients + data = np.zeros((1, (n_max+1) ** 2, 16), dtype=complex) + if normalization == 'N3D': + data[0, 0, :] = np.sqrt(4 * np.pi) + else: + data[0, 0, :] = 1.0 + + is_complex = True + if basis_type == 'real': + data = np.real(data) + is_complex = False + + # generate unit amplitude sh signal + a_nm = SphericalHarmonicSignal(data, + basis_type=basis_type, + channel_convention='ACN', + condon_shortley=condon_shortley, + normalization=normalization, + sampling_rate=48000, + is_complex=is_complex) + a = isht(a_nm, sampling) + assert a_nm.n_samples == a.n_samples + sh = SphericalHarmonics(n_max=n_max, + coordinates=sampling, + basis_type=basis_type, + normalization=normalization, + condon_shortley=condon_shortley) + a_eval_nm = sht(a, sh) + assert a_eval_nm.n_samples == a.n_samples + npt.assert_allclose(a_nm.time, a_eval_nm.time, rtol=1e-14, atol=1e-14) diff --git a/tests/test_spherical_harmonics_signal.py b/tests/test_spherical_harmonics_signal.py index 22c146e7..afaa5a7c 100644 --- a/tests/test_spherical_harmonics_signal.py +++ b/tests/test_spherical_harmonics_signal.py @@ -126,8 +126,8 @@ def test_init_wrong_basis_type(): [1., 2., 3.], [1., 2., 3.]]).reshape(1, 4, 3) with pytest.raises(ValueError, - match="Invalid basis type, only " - "'complex' and 'real' are supported"): + match="Invalid basis type, only " + "'complex' and 'real' are supported"): SphericalHarmonicSignal(data, 44100, basis_type='invalid_basis_type', channel_convention='ACN',