Skip to content

Commit 4d25b21

Browse files
committed
Enable testCaiToJax in array interoperability test on ROCm platform
This change enables tests/array_interoperability_test.py::CudaArrayInterfaceTest::testCaiToJax on ROCm GPUs:
1. jaxlib/py_array.cc: accept the ROCm platform alongside CUDA in the `__cuda_array_interface__` property check.
2. jax/_src/numpy/array_constructors.py: add dynamic discovery of the ROCm plugin extension and select the CUDA or ROCm plugin based on the backend's platform version.
3. tests/array_interoperability_test.py: switch testCaiToJax to `@jtu.run_on_devices("gpu")` and `jax.devices('gpu')` so it runs on both CUDA and ROCm.
1 parent 191504b commit 4d25b21

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

jax/_src/numpy/array_constructors.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@
         else:
             break
 
+# Dynamically find and load the ROCm plugin extension, if one is installed.
+rocm_plugin_extension = None
+try:
+  from importlib.metadata import distributions
+  for dist in distributions():
+    name = dist.metadata.get('Name', '')
+    if name.startswith('jax-rocm') and name.endswith('-plugin'):
+      module_name = name.replace('-', '_')
+      try:
+        rocm_plugin_extension = importlib.import_module(
+            f'{module_name}.rocm_plugin_extension'
+        )
+        break
+      except ImportError:
+        continue
+except Exception:
+  pass
+

4866
def _supports_buffer_protocol(obj):
4967
try:
@@ -218,11 +236,17 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
     object = object.__jax_array__()
   elif hasattr(object, '__cuda_array_interface__'):
     cai = object.__cuda_array_interface__
-    backend = xla_bridge.get_backend("cuda")
-    if cuda_plugin_extension is None:
+    backend = xla_bridge.get_backend()
+    if 'rocm' in backend.platform_version.lower():
+      gpu_plugin_extension = rocm_plugin_extension
+    elif 'cuda' in backend.platform_version.lower():
+      gpu_plugin_extension = cuda_plugin_extension
+    else:
+      gpu_plugin_extension = None
+    if gpu_plugin_extension is None:
       device_id = None
     else:
-      device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
+      device_id = gpu_plugin_extension.get_device_ordinal(cai["data"][0])
     object = xc._xla.cuda_array_interface_to_buffer(
         cai=cai, gpu_backend=backend, device_id=device_id)
228252

jaxlib/py_array.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,9 +985,10 @@ nb::dict PyArray::CudaArrayInterface() {
   ifrt::Array* ifrt_array = arr.ifrt_array();
   std::optional<xla::Shape>& scratch = arr.GetStorage().dynamic_shape;
   auto* pjrt_buffer = GetPjrtBuffer(ifrt_array);
-  if (pjrt_buffer->client()->platform_id() != xla::CudaId()) {
+  if (pjrt_buffer->client()->platform_id() != xla::CudaId() &&
+      pjrt_buffer->client()->platform_id() != xla::RocmId()) {
     throw nb::attribute_error(
-        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
+        "__cuda_array_interface__ is only defined for GPU buffers.");
   }
   if (pjrt_buffer->IsTuple()) {
     throw nb::attribute_error(

tests/array_interoperability_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def testCuPyToJax(self, shape, dtype):
       shape=all_shapes,
       dtype=jtu.dtypes.supported(cuda_array_interface_dtypes),
   )
-  @jtu.run_on_devices("cuda")
+  @jtu.run_on_devices("gpu")
   def testCaiToJax(self, shape, dtype):
     dtype = np.dtype(dtype)
398398

@@ -401,7 +401,7 @@ def testCaiToJax(self, shape, dtype):
 
     # using device with highest device_id for testing the correctness
     # of detecting the device id from a pointer value
-    device = jax.devices('cuda')[-1]
+    device = jax.devices('gpu')[-1]
     with jax.default_device(device):
       y = jnp.array(x, dtype=dtype)
     # TODO(parkers): Remove after setting 'stream' properly below.

0 commit comments

Comments
 (0)