
Avoiding large constant closures in model #241

@bzh-bzh


Hi Joshua,

Is there a good way to pass large constant arrays into the log-likelihood explicitly, as a function argument, rather than implicitly as constants that JAX closes over?
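To make the distinction in the question concrete, here is a minimal sketch in plain JAX, with no jaxns involved. The Gaussian log-likelihood and the `data` shape are hypothetical placeholders. In the first function, `data` is captured from the enclosing scope and becomes a constant of the traced computation. In the second, it is an ordinary input, so `jax.jit` traces it abstractly and it keeps whatever device placement it already has.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the large per-star data array (stars x observations).
data = jnp.ones((1000, 8))

# Closure style: `data` comes from the enclosing scope, so when jax.jit traces
# this function the concrete array is captured as a constant of the program.
@jax.jit
def log_likelihood_closure(theta):
    return -0.5 * jnp.sum((data - theta) ** 2)

# Explicit-argument style: `data` is an input of the jitted function, so it is
# traced as an abstract array and its existing placement/sharding is respected.
@jax.jit
def log_likelihood_explicit(theta, data):
    return -0.5 * jnp.sum((data - theta) ** 2)

theta = 0.5
print(log_likelihood_closure(theta))
print(log_likelihood_explicit(theta, data))
```

Both calls compute the same value; the difference is only in how `data` enters the compiled program.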

For context, we're hoping to use jaxns for a gravitational-wave search using relative stellar astrometry. The real dataset is larger than a single GPU's memory, so I've been trying to use a sharded array with JAX's automatic multi-host/multi-device parallelism.
Since our log-likelihood is independent between stars, evaluating it is just a parallel computation over stars followed by a reduce-sum at the end.
Unfortunately, JAX doesn't currently support sharded arrays as closed-over constants in multi-host settings.
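The per-star structure described above maps naturally onto `jax.vmap` plus a sum, with the star axis sharded over a device mesh. This is a single-host sketch under stated assumptions: the per-star likelihood is a hypothetical placeholder, and the mesh axis name `"stars"` is arbitrary. In a true multi-host run, the global array would instead be assembled from each process's local shard (for example with `jax.make_array_from_process_local_data`) rather than with `jax.device_put`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def per_star_log_likelihood(theta, star_obs):
    # Hypothetical Gaussian likelihood for a single star's observations.
    return -0.5 * jnp.sum((star_obs - theta) ** 2)

@jax.jit
def total_log_likelihood(theta, data):
    # Evaluate every star independently, then reduce-sum over the star axis.
    per_star = jax.vmap(per_star_log_likelihood, in_axes=(None, 0))(theta, data)
    return jnp.sum(per_star)

# 1-D device mesh; the leading (star) axis of `data` is split across devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("stars",))
sharding = NamedSharding(mesh, P("stars"))

n_devices = len(jax.devices())
n_stars = 4 * n_devices  # chosen so the star axis divides evenly across devices
data = jax.device_put(jnp.ones((n_stars, 8)), sharding)

print(total_log_likelihood(0.5, data))
```

Because `data` enters `total_log_likelihood` as an argument rather than a closed-over constant, the compiler should be able to partition the per-star work and insert the cross-device reduction for the final sum itself.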

I've tried passing the data in as a parameterised prior variable, but those variables seem to get replicated somewhere in the jaxns internals, which causes a very large GPU memory allocation and an OOM error when the sampler is lowered.
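One general JAX pattern that may help is to make the data an argument of the outermost jitted function and build the log-likelihood closure inside it. This assumes the sampler can be invoked inside an outer `jax.jit`, which is not confirmed here for jaxns. The closure then captures a tracer rather than a concrete array, so nothing large is baked in as a constant and the data never has to pass through the prior. `toy_sampler` below is a hypothetical stand-in for a sampler that only accepts a one-argument log-likelihood; it is not the jaxns API.

```python
import jax
import jax.numpy as jnp

def toy_sampler(log_likelihood, key, num_samples=256):
    # Hypothetical stand-in for a sampler (NOT the jaxns API): draws uniform
    # proposals and returns the one with the highest log-likelihood.
    thetas = jax.random.uniform(key, (num_samples,), minval=-1.0, maxval=1.0)
    log_ls = jax.vmap(log_likelihood)(thetas)
    best = jnp.argmax(log_ls)
    return thetas[best], log_ls[best]

@jax.jit
def run(key, data):
    # `data` is an argument of the outermost jitted function, so inside this
    # body it is a tracer. The closure below captures that tracer instead of a
    # concrete array, so no large constant is embedded in the program.
    def log_likelihood(theta):
        return -0.5 * jnp.sum((data - theta) ** 2)
    return toy_sampler(log_likelihood, key)

data = jnp.full((1000, 8), 0.25)
best_theta, best_log_l = run(jax.random.PRNGKey(0), data)
print(best_theta, best_log_l)
```

The same outer-argument approach combines with a sharded `data` array: the tracer inside `run` carries the input's sharding, rather than a constant that would have to be replicated.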

Thanks!
