Skip to content

Commit ab97710

Browse files
venkatramnankVenkat Ramnan Kalyanakumarawni
authored
feat: Added dlpack device (#1165)
* feat: Added dlpack device * feat: Added device_id to dlpack device * feat: Added device_id to dlpack device * doc: updated conversion docs * doc: updated numpy.rst dlpack information * doc: updated numpy.rst dlpack information * Update docs/src/usage/numpy.rst * Update docs/src/usage/numpy.rst --------- Co-authored-by: Venkat Ramnan Kalyanakumar <[email protected]> Co-authored-by: Awni Hannun <[email protected]>
1 parent fd1c081 commit ab97710

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

docs/src/usage/numpy.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
Conversion to NumPy and Other Frameworks
44
========================================
55

6-
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
6+
MLX array supports conversion between other frameworks with either:
7+
8+
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
9+
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
10+
711
Let's convert an array to NumPy and back.
812

913
.. code-block:: python

python/src/array.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
#include <nanobind/stl/variant.h>
1111
#include <nanobind/stl/vector.h>
1212

13+
#include "mlx/backend/metal/metal.h"
1314
#include "python/src/buffer.h"
1415
#include "python/src/convert.h"
1516
#include "python/src/indexing.h"
1617
#include "python/src/utils.h"
1718

19+
#include "mlx/device.h"
1820
#include "mlx/ops.h"
1921
#include "mlx/transforms.h"
2022
#include "mlx/utils.h"
@@ -353,6 +355,17 @@ void init_array(nb::module_& m) {
353355
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
354356
})
355357
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
358+
.def(
359+
"__dlpack_device__",
360+
[](const array& a) {
361+
if (metal::is_available()) {
362+
// Metal device is available
363+
return nb::make_tuple(8, 0);
364+
} else {
365+
// CPU device
366+
return nb::make_tuple(1, 0);
367+
}
368+
})
356369
.def("__copy__", [](const array& self) { return array(self); })
357370
.def(
358371
"__deepcopy__",

python/tests/test_array.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ def test_list_not_equals_array(self):
161161
self.assertTrue(a != b)
162162
self.assertTrue(a != c)
163163

164+
def test_dlx_device_type(self):
165+
a = mx.array([1, 2, 3])
166+
device_type, device_id = a.__dlpack_device__()
167+
self.assertIn(device_type, [1, 8])
168+
self.assertEqual(device_id, 0)
169+
170+
if device_type == 8:
171+
# Additional check if Metal is supposed to be available
172+
self.assertTrue(mx.metal.is_available())
173+
elif device_type == 1:
174+
# Additional check if CPU is the fallback
175+
self.assertFalse(mx.metal.is_available())
176+
164177
def test_tuple_not_equals_array(self):
165178
a = mx.array([1, 2, 3])
166179
b = (1, 2, 3)

0 commit comments

Comments
 (0)