Unexpected increase in runtime when using the device option in jax.jit() #26859
Unanswered · eduardolneto asked this question in Q&A
Hi,
I am trying to integrate a JAX-based Python package into another, non-JAX Python package that uses mpi4py for parallelization.
In more detail, I am using an optimisation framework that, given a number of CPU cores, runs the same number of slightly different inputs through a series of calculations. This is done using mpi4py, so each core processes one input in parallel, and the output results therefore end up on their respective CPU cores.
After this, I would like to run each of these results through another application, which is built in JAX, and ideally to do so on the same, already separated CPU cores.
Searching for the best approach, I have tried to select the respective CPU device using the device argument of jax.jit and the mpi4py world-communicator rank, in the following way:
```python
rank = MPI.COMM_WORLD.Get_rank()  # mpi4py world rank
result = jax.jit(jax_function, device=jax.devices('cpu')[rank])(input)
```
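For context, a minimal self-contained version of this pattern looks roughly as follows (here `jax_function` and its input are just placeholders for my actual computation, and the device count is an assumption):

```python
import os
# Request several XLA CPU devices; must be set before JAX initializes.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()

def jax_function(x):
    # Placeholder for the real JAX computation.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.ones(1000)
# Pin the jitted computation to this rank's CPU device.
result = jax.jit(jax_function, device=jax.devices("cpu")[rank])(x)
print(rank, result)
```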
The expected timing of the single-core application is ~40 s. When using 2 processes, I obtain the same timing for each process. However, for 3 or more CPU cores, the timings increase to 600 s and above. Furthermore, in these cases (3-8 cores) the application seems to use all 8 cores of the machine, even though I explicitly chose:
```sh
mpirun -n 3 python application.py
```
with
```python
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=3'
```
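(One detail that may or may not matter: as far as I understand, XLA_FLAGS only takes effect if it is set before JAX initializes its backends, so in practice it should be set before jax is imported. A quick sanity check along these lines:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=3"  # before importing jax

import jax
print(jax.devices("cpu"))  # should list 3 CPU devices per process
```

Note that each MPI process is a separate Python interpreter, so each process creates its own copy of these 3 devices.)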
The main idea, using a small example with a diffrax application, would be as sketched below.
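This is a minimal sketch of that small example (a toy diffrax ODE solve stands in for the actual one; the solver choice and problem sizes here are illustrative):

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import diffrax
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()

def solve(y0):
    # Toy ODE dy/dt = -y, standing in for the real diffrax application.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(term, diffrax.Tsit5(),
                              t0=0.0, t1=10.0, dt0=0.01, y0=y0)
    return sol.ys

y0 = jnp.ones(1000) * (rank + 1)  # slightly different input per rank
result = jax.jit(solve, device=jax.devices("cpu")[rank])(y0)
print(rank, result)
```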
However, in this small example the idea seems to work for any number of processes (1-8), and the timings of every process are similar to the timing of a single-core run (2.0 s). I also tried a third application, a diffrax-based JAX Vlasov solver that takes 30 s to run on a single core, and there I again see the behaviour I observe for the general application.
So I would assume that the duration of a single-core run is important for what I am observing.
Thus, my questions are:
- Can anyone explain the cause of what I am observing?
- Is there a different way to accomplish the linking between JAX and the MPI communicator in these cases that you could point me to?
Thank you in advance!