Where is TPU device classification defined in JAX? #26874
Unanswered
samixyzdev
asked this question in
Q&A
Replies: 0 comments
Hi JAX team and contributors,
I'm looking to contribute to JAX by modifying how TPU devices are classified.
Currently, multi-slice TPU devices are labeled as `MegaScalePjRtDevice`, but I believe they should remain `TpuDevice`.

I've tried tracing the classification process and followed these steps:

1. Searched for `MegaScalePjRtDevice` in the JAX codebase.
2. Found `make_tpu_client()`, which is called in `xla_bridge.py`.
3. `make_tpu_client()` is inside `jaxlib`, likely in a compiled `.so` file.
4. Looked for `libtpu.so` but didn't find it on my system.
5. Looked into `pjrt_plugin`, but I couldn't confirm its exact behavior.

Question: Where exactly in the JAX codebase is the TPU device classification (`MegaScalePjRtDevice` vs. `TpuDevice`) handled? I'd love to contribute a fix, but I need to understand where this happens.

Any guidance is appreciated! Thanks in advance. 😊
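
For anyone trying to reproduce the labeling, here is a minimal sketch of how it could be inspected from Python. It assumes the label shown for a device is the part of `repr(device)` before the first parenthesis, which is an assumption and not something confirmed from the JAX source. `summarize_devices` and `print_device_summary` are hypothetical helper names, not JAX APIs; only `jax.devices()` and the `platform` and `device_kind` attributes are real.

```python
from collections import Counter


def summarize_devices(devices):
    """Count devices by (repr label, platform, device_kind).

    Assumption: the label (e.g. "TpuDevice") is whatever precedes the
    first "(" in repr(device). This is a guess about the repr format.
    """
    return Counter(
        (
            repr(d).split("(", 1)[0],
            getattr(d, "platform", None),
            getattr(d, "device_kind", None),
        )
        for d in devices
    )


def print_device_summary():
    """Print one line per distinct device label reported by JAX."""
    import jax  # imported lazily so summarize_devices works without JAX

    for (label, platform, kind), n in summarize_devices(jax.devices()).items():
        print(f"{n} x {label} (platform={platform}, device_kind={kind})")
```

Calling `print_device_summary()` on a single-slice host and on a multi-slice host would show whether the label differs between the two setups while `platform` and `device_kind` stay the same.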