Hello! I'd like to understand whether I should expect sharp edges when working with jax-finufft and multi-GPU JAX.
Since I'm still getting set up on multi-GPU hardware, I created a toy example on the CPU for discussion. Running nufft1 on sharded inputs, with JAX forced to use 4 CPU devices, results in a roughly 3x slowdown. To be clear, I'm not claiming this behavior is unexpected; I'd just like to understand better how to think about multi-device JAX + jax-finufft.
Setup:
import jax
import jax.sharding as jshard
import numpy as np

from jax_finufft import nufft1

jax.config.update("jax_num_cpu_devices", 4)

mesh = jax.make_mesh((len(jax.devices()),), ("batch",))
sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))

M = 4 * 1000
N = 4 * 2000
x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
x_sh, c_sh = jax.device_put((x, c), sharding)

f = jax.jit(lambda _x, _c: nufft1(N, _c, _x, eps=1e-6, iflag=1))
Output:
%timeit f(x, c).block_until_ready()
449 μs ± 3.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit f(x_sh, c_sh).block_until_ready()
1.45 ms ± 239 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
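One diagnostic I plan to try (a sketch; I haven't analyzed the output in detail yet) is dumping the compiled HLO for the sharded call and checking for collectives around the NUFFT custom call:

lowered = jax.jit(lambda _x, _c: nufft1(N, _c, _x, eps=1e-6, iflag=1)).lower(x_sh, c_sh)
# Collectives such as all-gather around the custom call would point to
# resharding overhead rather than the transform itself.
print(lowered.compile().as_text())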
jax.print_environment_info()
jax: 0.6.2
jaxlib: 0.6.2
numpy: 2.3.2
python: 3.11.4 (main, Jun 15 2023, 07:55:38) [Clang 14.0.3 (clang-1403.0.22.14.1)]
device info: cpu-4, 4 local devices
process_count: 1
platform: uname_result(system='Darwin', node=..., release='24.6.0', version='Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:51 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T8112', machine='arm64')
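In case it's useful for the discussion: since the type-1 transform f_k = sum_j c_j exp(i k x_j) is additive over source points, one pattern I considered (an untested sketch, assuming nufft1 can be called inside shard_map; names here are my own) is computing per-shard partial transforms and combining them with a psum, reusing mesh, N, and nufft1 from the setup above:

from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

def local_nufft(_x, _c):
    # Each device transforms only its shard of the points; these partial
    # spectra are exact additive contributions to the full transform.
    f_local = nufft1(N, _c, _x, eps=1e-6, iflag=1)
    return jax.lax.psum(f_local, "batch")

f_shmap = jax.jit(
    shard_map(local_nufft, mesh=mesh, in_specs=(P("batch"), P("batch")), out_specs=P())
)

This would avoid resharding the inputs entirely, at the cost of one collective sum over the dense output; whether that wins presumably depends on N and the number of devices. Is something like this the intended pattern, or is there a better way to think about it?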