File tree Expand file tree Collapse file tree 3 files changed +31
-1
lines changed
Expand file tree Collapse file tree 3 files changed +31
-1
lines changed Original file line number Diff line number Diff line change 33Conversion to NumPy and Other Frameworks
44========================================
55
6- MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html >`_.
6+ MLX array supports conversion between other frameworks with either:
7+
8+ * The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html >`_.
9+ * `DLPack <https://dmlc.github.io/dlpack/latest/ >`_.
10+
711Let's convert an array to NumPy and back.
812
913.. code-block :: python
Original file line number Diff line number Diff line change 1010#include < nanobind/stl/variant.h>
1111#include < nanobind/stl/vector.h>
1212
13+ #include " mlx/backend/metal/metal.h"
1314#include " python/src/buffer.h"
1415#include " python/src/convert.h"
1516#include " python/src/indexing.h"
1617#include " python/src/utils.h"
1718
19+ #include " mlx/device.h"
1820#include " mlx/ops.h"
1921#include " mlx/transforms.h"
2022#include " mlx/utils.h"
@@ -353,6 +355,17 @@ void init_array(nb::module_& m) {
353355 new (&arr) array (nd_array_to_mlx (state, std::nullopt ));
354356 })
355357 .def (" __dlpack__" , [](const array& a) { return mlx_to_dlpack (a); })
358+ .def (
359+ " __dlpack_device__" ,
360+ [](const array& a) {
361+ if (metal::is_available ()) {
362+ // Metal device is available
363+ return nb::make_tuple (8 , 0 );
364+ } else {
365+ // CPU device
366+ return nb::make_tuple (1 , 0 );
367+ }
368+ })
356369 .def (" __copy__" , [](const array& self) { return array (self); })
357370 .def (
358371 " __deepcopy__" ,
Original file line number Diff line number Diff line change @@ -161,6 +161,19 @@ def test_list_not_equals_array(self):
161161 self .assertTrue (a != b )
162162 self .assertTrue (a != c )
163163
164+ def test_dlx_device_type (self ):
165+ a = mx .array ([1 , 2 , 3 ])
166+ device_type , device_id = a .__dlpack_device__ ()
167+ self .assertIn (device_type , [1 , 8 ])
168+ self .assertEqual (device_id , 0 )
169+
170+ if device_type == 8 :
171+ # Additional check if Metal is supposed to be available
172+ self .assertTrue (mx .metal .is_available ())
173+ elif device_type == 1 :
174+ # Additional check if CPU is the fallback
175+ self .assertFalse (mx .metal .is_available ())
176+
164177 def test_tuple_not_equals_array (self ):
165178 a = mx .array ([1 , 2 , 3 ])
166179 b = (1 , 2 , 3 )
You can’t perform that action at this time.
0 commit comments