Currently the data type of the output arrays of the JAX forward and inverse spherical transforms in s2fft.precompute_transforms is fixed to complex128, irrespective of the data types of the signal / coefficient and kernel array arguments to the functions.
While I think using double precision is important for numerical stability when computing the recursions used to construct the Wigner-d kernel, once the kernels are computed the operations in the forward and inverse transforms largely correspond to FFTs plus an einsum, both of which should be computable stably in single precision. This means we should be able to, for example, compute the kernel in double precision and then cast it to single precision, and use it with single-precision signal / coefficient arrays while maintaining round-trip errors close to (single) floating-point precision.
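As a generic illustration of why this should work (this does not use s2fft itself; the kernel and signal arrays here are hypothetical stand-ins with illustrative shapes), a contraction whose kernel is computed in double precision and then cast down to single precision agrees with the double-precision result to roughly float32 machine epsilon:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for a precomputed kernel and a sampled signal;
# the shapes are illustrative, not the actual s2fft array layouts.
kernel_f64 = rng.standard_normal((64, 64))
signal_f64 = rng.standard_normal(64)

# Reference contraction, everything in double precision.
coeffs_f64 = np.einsum("lm,m->l", kernel_f64, signal_f64)

# Same contraction with the kernel computed in double precision and then
# cast to single precision, applied to a single-precision signal.
coeffs_f32 = np.einsum(
    "lm,m->l",
    kernel_f64.astype(np.float32),
    signal_f64.astype(np.float32),
)

# Maximum relative error, typically a small multiple of float32 epsilon.
relative_error = np.abs(coeffs_f64 - coeffs_f32).max() / np.abs(coeffs_f64).max()
print(coeffs_f32.dtype, relative_error)
```

The point of the sketch is that the rounding error of the contraction itself is benign; the recursions that build the kernel are the only precision-sensitive step.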
As some limited evidence in favour of this: after changing the occurrences in the code paths involved in the precompute transforms from explicitly allocating arrays as complex128 to instead inheriting the data type from the arguments, running the following script
```python
import jax
import numpy as np
import s2fft
import s2fft.precompute_transforms

jax.config.update("jax_enable_x64", True)

L = 128
spin = 0
sampling = "mw"
reality = True
recursion = "auto"
method = "jax"

rng = np.random.default_rng(1234)
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
f = s2fft.inverse(flm, L=L, spin=spin, sampling=sampling, reality=reality, method=method)

kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
kernel = kernel_function(
    L=L, spin=spin, reality=reality, sampling=sampling, forward=True, recursion=recursion
)

flm_recovered_f64 = s2fft.precompute_transforms.spherical.forward(
    f=f,
    L=L,
    spin=spin,
    kernel=kernel,
    sampling=sampling,
    reality=reality,
    method=method,
)
round_trip_error_f64 = abs(flm - flm_recovered_f64).max()

flm_recovered_f32 = s2fft.precompute_transforms.spherical.forward(
    f=f.astype("float32"),
    L=L,
    spin=spin,
    kernel=kernel.astype("float32"),
    sampling=sampling,
    reality=reality,
    method=method,
)
assert flm_recovered_f32.dtype == "complex64"
round_trip_error_f32 = abs(flm - flm_recovered_f32).max()

print(f"Round-trip error: float64 = {round_trip_error_f64:.2e}, float32 = {round_trip_error_f32:.2e}")
```

gives the output
```
Round-trip error: float64 = 1.43e-12, float32 = 2.31e-06
```
which suggests the precompute transforms can be computed with acceptable round-trip error at single floating-point precision, provided the kernel is first computed at double precision.
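For concreteness, the proposed change could look roughly like the following sketch. The function name, shapes, and einsum subscripts here are hypothetical and only illustrate the allocation pattern, not the actual s2fft internals: intermediates derive their dtype from the inputs via `np.result_type` instead of a hard-coded complex128.

```python
import numpy as np

def forward_sketch(f, kernel):
    # Illustrative sketch only (not the actual s2fft implementation):
    # derive the working complex dtype from the inputs, so float32 /
    # complex64 inputs give complex64 and float64 inputs give complex128.
    dtype = np.result_type(f.dtype, kernel.dtype, np.complex64)
    # NumPy's FFT always returns complex128, so cast back to the
    # inherited dtype rather than silently promoting everything.
    ftm = np.fft.fft(f, axis=-1).astype(dtype)
    # The transform is then, schematically, a kernel contraction
    # over the sample axis.
    return np.einsum("tlm,tm->lm", kernel.astype(dtype), ftm)

f32 = np.zeros((8, 15), dtype=np.float32)
k32 = np.zeros((8, 4, 15), dtype=np.float32)
out = forward_sketch(f32, k32)
print(out.dtype)  # complex64
```

With this pattern a double-precision kernel cast to float32 and a float32 signal yield complex64 output, matching the `flm_recovered_f32.dtype == "complex64"` assertion in the script above, while double-precision inputs still produce complex128.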