Replies: 4 comments 24 replies
-
|
Hi, 1a) It is possible to write a JaxOperator and write its implementation in plain Python/Numba. You 'just' need to use a
You can try this package that I experimented with, https://github.com/mpi4jax/mpibackend4jax, and let me know how it goes. It should allow you to still use MPI.
-
|
Some options for writing a wrapper for the Numba operator mentioned in point 1 above:
I) use jax.pure_callback inside of shard_map
II) just run it outside of jit
III) run Numba inside of JAX directly
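As a rough illustration of options I and II, here is a minimal sketch of calling host-side code from a jitted function via jax.pure_callback. The names host_kernel and wrapped are hypothetical, and host_kernel is a plain NumPy stand-in (an assumption, not NetKet's actual kernel) for a Numba-compiled function. The shard_map part of option I is left out to keep the sketch short.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_kernel(x):
    # Hypothetical stand-in for a Numba-compiled kernel: it runs on the host
    # and works on NumPy arrays. A function decorated with numba.njit would be
    # called in the same way.
    x = np.asarray(x)
    return (2.0 * x + 1.0).astype(x.dtype)


@jax.jit
def wrapped(x):
    # pure_callback needs the shape and dtype of the result declared up front,
    # because JAX cannot trace through the host function.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_kernel, out_spec, x)


x = jnp.arange(4, dtype=jnp.float32)
y_jit = wrapped(x)        # option I (without shard_map): host code called inside jit
y_eager = host_kernel(x)  # option II: the same host code called outside of jit
```

Every callback call moves data between the device and the host, so this route is probably best suited to CPU runs or to kernels that are called rarely.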
-
|
Here is my attempt; can someone please give me some feedback? I would really appreciate it! If it all checks out, I would be happy to write a tutorial for it, or perhaps a "Jax local operator wrapper" class, if there is actually a use case for it beyond my own. P.S. Please don't yell at me if it is really horrible, and sorry for the long discussion that follows. What I did: I went with suggestion (1a) by @PhilipVinc (
|
-
What's the error?
-
Hi,
I have a question about using NetKet with parallelisation enabled.
So far, I have used Numba-based LocalOperators with MPI for parallelisation, purely on CPU. I did this because I had to modify the get_conn_flattened_kernel function, which was easier to do in the Numba version. I now see in the latest release that MPI is deprecated and that parallelisation is carried out in JAX.
My questions are: can I still use my Numba operators and make use of parallelisation, or do I have to rewrite my Numba operator in JAX? If I can still use the Numba-based operators, are there any performance drawbacks? Also, does the JAX parallelisation run on both CPU and GPU, or only on GPU?
Thanks in advance!