Hi Joshua,
Is there a good way to pass large constant arrays into the log-likelihood explicitly as function arguments, rather than implicitly as constants that JAX closes over?
For context, we're hoping to use jaxns for a gravitational wave search with relative stellar astrometry. For the real data, we have more than can fit on a single GPU, so I've been trying to use a sharded array with automatic multi-host/multi-device parallelism.
Since our log-likelihood is independent between stars, it'd just be a parallel evaluation, and then a reduce-sum at the end.
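For concreteness, here's a minimal sketch of the structure I have in mind; the signal model, data shapes, and parameter names are placeholders, not our actual pipeline:

```python
import jax
import jax.numpy as jnp

def per_star_log_likelihood(params, star_data):
    # Hypothetical Gaussian residual term for one star's astrometric time series.
    model = params["amplitude"] * star_data["times"]  # placeholder signal model
    resid = star_data["obs"] - model
    return -0.5 * jnp.sum(resid ** 2 / star_data["var"])

def log_likelihood(params, data):
    # vmap over the leading (star) axis, then reduce-sum:
    # each star contributes independently to the total log-likelihood.
    per_star = jax.vmap(per_star_log_likelihood, in_axes=(None, 0))(params, data)
    return jnp.sum(per_star)

# Toy data: 4 stars, 10 epochs each. In the real case, `data` would be
# a sharded array spanning multiple hosts/devices.
data = {
    "times": jnp.tile(jnp.linspace(0.0, 1.0, 10), (4, 1)),
    "obs": jnp.zeros((4, 10)),
    "var": jnp.ones((4, 10)),
}
params = {"amplitude": jnp.array(0.1)}
print(float(log_likelihood(params, data)))
```

The question is how to get `data` into the log-likelihood as an explicit argument so its sharding survives, rather than as a closed-over constant.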
Unfortunately, JAX doesn't currently support sharded arrays as closed-over constants across multiple hosts.
I've tried passing the data in as a parameterised prior variable, but those seem to get replicated somewhere in the jaxns internals, causing a very large GPU memory allocation and an OOM error when the sampler is lowered.
Thanks!