
Support using lower floating-point precisions for arguments in precompute transforms #320

@matt-graham


Currently the output arrays of the JAX forward and inverse spherical transforms in `s2fft.precompute_transforms` always have data type `complex128`, irrespective of the data types of the signal / coefficient and kernel array arguments passed to the functions.
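A rough sketch of inheriting the output type from the arguments (not the current s2fft code): derive the complex output dtype from the argument dtypes using NumPy's type-promotion rules instead of hardcoding it. The helper name `complex_output_dtype` is hypothetical. `jax.numpy.result_type` and `jax.numpy.promote_types` are the JAX counterparts, though JAX's promotion rules differ from NumPy's in some mixed-type cases.

```python
import numpy as np


def complex_output_dtype(*arrays):
    """Hypothetical helper: the complex dtype matching the promoted precision
    of all array arguments, e.g. float32 -> complex64, float64 -> complex128."""
    return np.promote_types(np.result_type(*arrays), np.complex64)


signal_f32 = np.zeros((4, 8), dtype=np.float32)
kernel_f32 = np.zeros((3, 4), dtype=np.float32)
kernel_f64 = np.zeros((3, 4), dtype=np.float64)

# Single-precision signal and kernel give a single-precision complex output.
print(complex_output_dtype(signal_f32, kernel_f32))  # complex64
# Mixing in a double-precision kernel promotes the output to complex128.
print(complex_output_dtype(signal_f32, kernel_f64))  # complex128

# An output buffer is then allocated with the inferred dtype
# rather than a fixed np.complex128.
out = np.zeros((3, 8), dtype=complex_output_dtype(signal_f32, kernel_f32))
```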

I think double precision is important for numerical stability when computing the recursions used to construct the Wigner-d kernel. Once the kernel has been computed, however, the forward and inverse transforms largely reduce to FFTs plus an einsum, and I think both can be computed stably in single precision. This means we should be able to, for example, compute the kernel in double precision, cast it to single precision, and use it with single-precision signal / coefficient arrays while keeping round-trip errors close to the working (single) floating-point precision.
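A toy NumPy sketch of this argument follows. A random real matrix stands in for the precomputed kernel. It is not an s2fft Wigner-d kernel, and the shapes and the `toy_forward` helper are hypothetical. The kernel is built in float64 and cast to float32, an FFT followed by an einsum is evaluated in single precision, and the result is compared against a double-precision reference. NumPy computes FFTs natively in single precision for float32 input only from version 2.0; earlier versions upcast internally, so the sketch casts the FFT output explicitly. In JAX, `jnp.fft.fft` and `jnp.einsum` keep float32 inputs in single precision.

```python
import numpy as np

rng = np.random.default_rng(0)
n_ell, n_theta, n_phi = 64, 64, 128

# Hypothetical stand-in for a precomputed real kernel, built in float64.
kernel_f64 = rng.standard_normal((n_ell, n_theta)) / np.sqrt(n_theta)
signal_f64 = rng.standard_normal((n_theta, n_phi))


def toy_forward(signal, kernel):
    """FFT along the last (longitude-like) axis, then contract the
    colatitude-like axis with the kernel via einsum."""
    complex_dtype = np.promote_types(signal.dtype, np.complex64)
    ftm = np.fft.fft(signal, axis=-1).astype(complex_dtype)
    return np.einsum("lt,tm->lm", kernel, ftm)


reference = toy_forward(signal_f64, kernel_f64)
# Cast both the double-precision kernel and the signal down to float32.
single = toy_forward(signal_f64.astype(np.float32), kernel_f64.astype(np.float32))

rel_error = np.abs(single - reference).max() / np.abs(reference).max()
print(single.dtype, f"relative error = {rel_error:.1e}")
```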

Here is some limited evidence in favour of this. I changed the places in the package's precompute-transform code paths that explicitly allocate arrays as `complex128` so that they instead inherit their data type from the arguments. Running the following script

```python
import jax
import numpy as np
import s2fft
import s2fft.precompute_transforms

# Enable 64-bit types in JAX so the reference kernel and transforms run in double precision.
jax.config.update("jax_enable_x64", True)

L = 128
spin = 0
sampling = "mw"
reality = True
recursion = "auto"
method = "jax"

# Generate random harmonic coefficients and the corresponding signal.
rng = np.random.default_rng(1234)
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
f = s2fft.inverse(flm, L=L, spin=spin, sampling=sampling, reality=reality, method=method)

# Precompute the forward kernel (in double precision, since x64 is enabled).
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
kernel = kernel_function(
    L=L, spin=spin, reality=reality, sampling=sampling, forward=True, recursion=recursion
)

# Forward transform with double-precision signal and kernel.
flm_recovered_f64 = s2fft.precompute_transforms.spherical.forward(
    f=f,
    L=L,
    spin=spin,
    kernel=kernel,
    sampling=sampling,
    reality=reality,
    method=method,
)
round_trip_error_f64 = abs(flm - flm_recovered_f64).max()

# Forward transform with signal and kernel both cast to single precision.
flm_recovered_f32 = s2fft.precompute_transforms.spherical.forward(
    f=f.astype("float32"),
    L=L,
    spin=spin,
    kernel=kernel.astype("float32"),
    sampling=sampling,
    reality=reality,
    method=method,
)

# With the dtype changes applied, the output dtype follows the single-precision inputs.
assert flm_recovered_f32.dtype == "complex64"

round_trip_error_f32 = abs(flm - flm_recovered_f32).max()

print(f"Round-trip error: float64 = {round_trip_error_f64:.2e}, float32 = {round_trip_error_f32:.2e}")
```

gives the output

```
Round-trip error: float64 = 1.43e-12, float32 = 2.31e-06
```

This suggests that the precompute transforms can be computed at single floating-point precision with acceptable round-trip error, provided the kernel is first computed at double precision.
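One way to put the reported errors in context is to compare them with the machine epsilon of each type. This is only a rough yardstick, because the errors above are absolute maximum errors and are not normalized by the magnitude of `flm`.

```python
import numpy as np

# Round-trip errors reported by the script above.
errors = {"float64": 1.43e-12, "float32": 2.31e-06}

# Express each error as a multiple of that type's machine epsilon.
ratios = {name: err / float(np.finfo(name).eps) for name, err in errors.items()}

for name, ratio in ratios.items():
    print(f"{name}: error = {errors[name]:.2e} = {ratio:.0f} x eps")
```

With float32 epsilon of about 1.2e-07 and float64 epsilon of about 2.2e-16, this gives roughly 19 × eps for float32 and roughly 6440 × eps for float64.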

Metadata

- Assignees: no one assigned
- Labels: enhancement (New feature or request)
- Type: no type
- Projects: no projects
- Milestone: no milestone
- Relationships: none yet
- Development: no branches or pull requests
