Replies: 1 comment
My high-level understanding is that the CUDA libraries are all provided by NVIDIA wheels (e.g. https://pypi.org/project/nvidia-cuda-runtime-cu12/), so they can be dynamically linked rather than bundled. Depending on your build system, I expect you could do the same, but I can't offer specific suggestions for how exactly to do it.
tsl::ImportNumpy() is an internal interface related to how XLA interacts with NumPy. I don't expect you would need anything similar unless you're using XLA internals.
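To make the first point a bit more concrete, here is a minimal sketch of one way a package could find and preload a CUDA library that was installed by an NVIDIA wheel, instead of shipping its own copy. It assumes the wheels place their shared libraries under site-packages/nvidia/<component>/lib (the layout I'd expect from nvidia-cuda-runtime-cu12 on Linux). The helper names find_nvidia_libs and preload are hypothetical and are not part of JAX or of any NVIDIA package, and this is not necessarily how jax-cuda12-plugin itself resolves its dependencies.

```python
import ctypes
import glob
import os
import sys


def find_nvidia_libs(pattern, search_roots=None):
    """Return sorted paths matching nvidia/*/lib/<pattern> under each root.

    Assumption: NVIDIA wheels install into <site-packages>/nvidia/<component>/lib.
    By default every existing directory on sys.path is searched.
    """
    roots = sys.path if search_roots is None else search_roots
    found = []
    for root in roots:
        if not root or not os.path.isdir(root):
            continue
        found.extend(glob.glob(os.path.join(root, "nvidia", "*", "lib", pattern)))
    return sorted(set(found))


def preload(pattern, search_roots=None):
    """Load the first matching library with RTLD_GLOBAL, or return None.

    Loading it globally before importing your own extension module lets the
    dynamic loader resolve the extension's CUDA symbols against this copy.
    """
    paths = find_nvidia_libs(pattern, search_roots)
    if not paths:
        return None
    return ctypes.CDLL(paths[0], mode=ctypes.RTLD_GLOBAL)
```

A package's __init__.py could call something like preload("libcudart.so*") before importing its compiled module, and declare the NVIDIA wheels as install dependencies. The alternative is to have the build system set an RPATH on the extension that points at the same directories, which avoids the Python-side lookup entirely.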
I have several functions similar to those in jaxlib/gpu/solver_kernels_ffi.cc, which I use to define custom JAX primitives. My goal is to package them into a Python wheel, similar to jax-cuda12-plugin.
Upon inspecting the .so files in jax-cuda12-plugin (e.g., _solver.so), I noticed that they are not linked to any CUDA runtime libraries. However, with my limited coding and packaging experience, I ended up bundling multiple CUDA libraries when building a wheel for PyPI distribution.
I’m curious about how jax-cuda12-plugin is packaged and whether I can follow a similar approach. Since I’ve never used Bazel before, I find it difficult to understand what’s going on in its build process.
Additionally, I came across the line tsl::ImportNumpy(); in the code - what does this function do?
Any insights would be greatly appreciated!