
jax-finufft + multi-GPU JAX #179

@michael-0brien

Description

Hello! I'd like to understand whether I should expect sharp edges when working with jax-finufft and multi-GPU JAX.

While I'm still getting set up on multi-GPU hardware, I created a toy example on the CPU for discussion. Calling nufft1 on sharded inputs, with JAX forced to expose 4 CPU devices, results in a roughly 3x slowdown over the unsharded call. To be clear, I'm not claiming this behavior is unexpected; I'd just like to understand better how to think about multi-device JAX + jax-finufft.

Setup:

import numpy as np

import jax
import jax.sharding as jshard
from jax_finufft import nufft1

# Force JAX to expose 4 CPU devices.
jax.config.update("jax_num_cpu_devices", 4)

# 1D mesh over all devices with a single "batch" axis.
mesh = jax.make_mesh((len(jax.devices()),), ("batch",))
sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))

M = 4 * 1000  # number of nonuniform points
N = 4 * 2000  # number of output modes

# Random nonuniform points in [0, 2*pi) and complex strengths.
x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)

# Copies of the same arrays, sharded across the 4 devices.
x_sh, c_sh = jax.device_put((x, c), sharding)

f = jax.jit(lambda _x, _c: nufft1(N, _c, _x, eps=1e-6, iflag=1))

Output:

%timeit f(x, c).block_until_ready()
449 μs ± 3.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%timeit f(x_sh, c_sh).block_until_ready()
1.45 ms ± 239 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
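To dig in a bit, I figured I could dump the compiled HLO for the sharded call and look for cross-device collectives (e.g. all-gathers that would pull everything back onto one device if the custom call has no sharding rule). A quick sketch reusing f, x_sh, and c_sh from above:

# Print the compiled module for the sharded inputs; all-gather /
# all-reduce ops here would indicate data movement around the NUFFT
# custom call rather than truly sharded execution.
print(f.lower(x_sh, c_sh).compile().as_text())

# Confirm how the inputs are actually placed across the 4 devices.
jax.debug.visualize_array_sharding(x_sh)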

jax.print_environment_info()

jax:    0.6.2
jaxlib: 0.6.2
numpy:  2.3.2
python: 3.11.4 (main, Jun 15 2023, 07:55:38) [Clang 14.0.3 (clang-1403.0.22.14.1)]
device info: cpu-4, 4 local devices
process_count: 1
platform: uname_result(system='Darwin', node=..., release='24.6.0', version='Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:51 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T8112', machine='arm64')
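In case it helps the discussion: since the type-1 transform is linear in the strengths, f_k = sum_j c_j * exp(i * k * x_j), a point-sharded input could in principle be handled by transforming each device's local block of points and summing the partial spectra. Below is a hedged sketch of that idea using jax.experimental.shard_map, reusing mesh and N from the setup above (the name f_manual is mine, and I haven't verified that jax-finufft's primitives are supported under shard_map):

from functools import partial

from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P("batch"), P("batch")), out_specs=P())
def f_manual(xb, cb):
    # Each device transforms only its local block of nonuniform points...
    fk_local = nufft1(N, cb, xb, eps=1e-6, iflag=1)
    # ...then the partial spectra are summed across the "batch" axis,
    # leaving the full length-N result replicated on every device.
    return jax.lax.psum(fk_local, axis_name="batch")

If something like this is the intended way to get parallelism here, it would at least make the decomposition explicit rather than relying on the compiler to partition the custom call.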
