Commit 4d25b21
committed
Enable testCaiToJax in array interoperability test on ROCm platform
This change enables the tests/array_interoperability_test.py::CudaArrayInterfaceTest::testCaiToJax on ROCm GPUs:
1. jaxlib/py_array.cc: Add ROCm platform check alongside CUDA for
__cuda_array_interface__ property support.
2. jax/_src/numpy/array_constructors.py:
- Add ROCm plugin extension discovery
3. tests/array_interoperability_test.py: Change testCaiToJax to use
@jtu.run_on_devices("gpu") to run on both CUDA and ROCm.1 parent 191504b commit 4d25b21
File tree
3 files changed
+32
-7
lines changed- jaxlib
- jax/_src/numpy
- tests
3 files changed
+32
-7
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
44 | 44 | | |
45 | 45 | | |
46 | 46 | | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
47 | 65 | | |
48 | 66 | | |
49 | 67 | | |
| |||
218 | 236 | | |
219 | 237 | | |
220 | 238 | | |
221 | | - | |
222 | | - | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
223 | 247 | | |
224 | 248 | | |
225 | | - | |
| 249 | + | |
226 | 250 | | |
227 | 251 | | |
228 | 252 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
985 | 985 | | |
986 | 986 | | |
987 | 987 | | |
988 | | - | |
| 988 | + | |
| 989 | + | |
989 | 990 | | |
990 | | - | |
| 991 | + | |
991 | 992 | | |
992 | 993 | | |
993 | 994 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
392 | 392 | | |
393 | 393 | | |
394 | 394 | | |
395 | | - | |
| 395 | + | |
396 | 396 | | |
397 | 397 | | |
398 | 398 | | |
| |||
401 | 401 | | |
402 | 402 | | |
403 | 403 | | |
404 | | - | |
| 404 | + | |
405 | 405 | | |
406 | 406 | | |
407 | 407 | | |
| |||
0 commit comments