Replies: 1 comment
-
My high level understanding is that the CUDA libraries are all provided by NVIDIA wheels (e.g. https://pypi.org/project/nvidia-cuda-runtime-cu12/), so those can be dynamically linked. Depending on your build system, I expect you could do the same, but I don't know that I can provide any specific suggestions for how exactly to do it.
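For illustration, a minimal sketch of that pattern, assuming the layout of the current nvidia-cuda-runtime-cu12 wheel (libraries under `nvidia/cuda_runtime/lib`); this is not necessarily what `jax-cuda12-plugin` itself does:

```python
# Sketch: depend on NVIDIA's redistributable wheels instead of bundling
# libcudart, and load the shared library before importing the extension.
import ctypes
import os

# Installed by the nvidia-cuda-runtime-cu12 wheel, which ships
# libcudart under nvidia/cuda_runtime/lib/.
import nvidia.cuda_runtime


def preload_cuda_runtime():
    libdir = os.path.join(os.path.dirname(nvidia.cuda_runtime.__file__), "lib")
    # RTLD_GLOBAL makes libcudart's symbols visible to extension modules
    # (e.g., a _solver.so) loaded afterwards, so they need no bundled copy.
    ctypes.CDLL(os.path.join(libdir, "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL)


preload_cuda_runtime()
```

With the library loaded under `RTLD_GLOBAL` before your extension is imported, the extension's undefined CUDA symbols resolve against the wheel-provided copy rather than one bundled into your own wheel.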
This is an internal interface related to how XLA interacts with numpy. I don't expect you would need to use anything similar unless you're using XLA internals.
-
I have several functions similar to those in `jaxlib/gpu/solver_kernels_ffi.cc`, which I use to define custom JAX primitives. My goal is to package them into a Python wheel, similar to `jax-cuda12-plugin`.
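For context, the Python side of what I have looks roughly like this (simplified; `my_solver_plugin`, `_solver`, and the target name are placeholders, and recent JAX exposes the FFI helpers under `jax.ffi`):

```python
# Simplified sketch of the Python-side registration of my CUDA FFI handlers.
import jax

from my_solver_plugin import _solver  # compiled extension with the FFI handlers

# The extension exposes its XLA FFI handlers as PyCapsules.
for name, capsule in _solver.registrations().items():
    jax.ffi.register_ffi_target(name, capsule, platform="CUDA")


def my_solve(a):
    # Build and invoke the call into the registered handler.
    call = jax.ffi.ffi_call("my_solver", jax.ShapeDtypeStruct(a.shape, a.dtype))
    return call(a)
```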
Upon inspecting the `.so` files in `jax-cuda12-plugin` (e.g., `_solver.so`), I noticed that they are not linked against any CUDA runtime libraries. However, with my limited coding and packaging experience, I ended up bundling multiple CUDA libraries when building a wheel for PyPI distribution.
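What I would like instead, if possible, is to declare the CUDA libraries as dependencies of my wheel rather than shipping the `.so` files myself, along these lines (a hypothetical sketch, not my actual packaging config):

```python
# Hypothetical setup.py: pull in NVIDIA's CUDA wheels as dependencies
# instead of bundling their shared libraries into this wheel.
from setuptools import find_packages, setup

setup(
    name="my-solver-plugin",  # placeholder name
    version="0.1.0",
    packages=find_packages(),
    package_data={"my_solver_plugin": ["*.so"]},  # only my own extension
    install_requires=[
        "nvidia-cuda-runtime-cu12",  # provides libcudart
        "nvidia-cusolver-cu12",      # cuSOLVER, which my kernels use
    ],
)
```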
I'm curious about how `jax-cuda12-plugin` is packaged and whether I can follow a similar approach. Since I've never used Bazel before, I find it difficult to understand what's going on in its build process.

Additionally, I came across the line `tsl::ImportNumpy();` in the code. What does this function do?

Any insights would be greatly appreciated!