Skip to content

Commit 81dd33a

Browse files
authored
allow conversion to dlpack (#1120)
1 parent 8b76571 commit 81dd33a

File tree

4 files changed

+41
-26
lines changed

4 files changed

+41
-26
lines changed

python/src/array.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -669,19 +669,14 @@ void init_array(nb::module_& m) {
669669
return a.shape(0);
670670
})
671671
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
672-
.def(
673-
"__getstate__",
674-
[](const array& a) {
675-
if (a.dtype() == bfloat16) {
676-
}
677-
return mlx_to_np_array(a);
678-
})
672+
.def("__getstate__", &mlx_to_np_array)
679673
.def(
680674
"__setstate__",
681675
[](array& arr,
682676
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
683677
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
684678
})
679+
.def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); })
685680
.def("__copy__", [](const array& self) { return array(self); })
686681
.def(
687682
"__deepcopy__",

python/src/convert.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ array nd_array_to_mlx(
100100
}
101101
}
102102

103-
template <typename Lib, typename T>
104-
nb::ndarray<Lib> mlx_to_nd_array(
103+
template <typename T, typename... NDParams>
104+
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
105105
array a,
106106
std::optional<nb::dlpack::dtype> t = {}) {
107107
{
@@ -110,47 +110,51 @@ nb::ndarray<Lib> mlx_to_nd_array(
110110
}
111111
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
112112
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
113-
return nb::ndarray<Lib>(
113+
return nb::ndarray<NDParams...>(
114114
a.data<T>(),
115115
a.ndim(),
116116
shape.data(),
117-
nb::handle(),
117+
nb::none(),
118118
strides.data(),
119119
t.value_or(nb::dtype<T>()));
120120
}
121121

122-
template <typename Lib>
123-
nb::ndarray<Lib> mlx_to_nd_array(const array& a) {
122+
template <typename... NDParams>
123+
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
124124
switch (a.dtype()) {
125125
case bool_:
126-
return mlx_to_nd_array<Lib, bool>(a);
126+
return mlx_to_nd_array_impl<bool, NDParams...>(a);
127127
case uint8:
128-
return mlx_to_nd_array<Lib, uint8_t>(a);
128+
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
129129
case uint16:
130-
return mlx_to_nd_array<Lib, uint16_t>(a);
130+
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
131131
case uint32:
132-
return mlx_to_nd_array<Lib, uint32_t>(a);
132+
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
133133
case uint64:
134-
return mlx_to_nd_array<Lib, uint64_t>(a);
134+
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
135135
case int8:
136-
return mlx_to_nd_array<Lib, int8_t>(a);
136+
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
137137
case int16:
138-
return mlx_to_nd_array<Lib, int16_t>(a);
138+
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
139139
case int32:
140-
return mlx_to_nd_array<Lib, int32_t>(a);
140+
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
141141
case int64:
142-
return mlx_to_nd_array<Lib, int64_t>(a);
142+
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
143143
case float16:
144-
return mlx_to_nd_array<Lib, float16_t>(a);
144+
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
145145
case bfloat16:
146-
return mlx_to_nd_array<Lib, bfloat16_t>(a, nb::bfloat16);
146+
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
147147
case float32:
148-
return mlx_to_nd_array<Lib, float>(a);
148+
return mlx_to_nd_array_impl<float, NDParams...>(a);
149149
case complex64:
150-
return mlx_to_nd_array<Lib, std::complex<float>>(a);
150+
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
151151
}
152152
}
153153

154154
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
155155
return mlx_to_nd_array<nb::numpy>(a);
156156
}
157+
158+
nb::ndarray<> mlx_to_dlpack(const array& a) {
159+
return mlx_to_nd_array<>(a);
160+
}

python/src/convert.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ using namespace mlx::core;
1313
array nd_array_to_mlx(
1414
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
1515
std::optional<Dtype> dtype);
16+
1617
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
18+
nb::ndarray<> mlx_to_dlpack(const array& a);

python/tests/test_array.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,20 @@ def test_add_numpy(self):
17221722
self.assertEqual(z.dtype, mx.int32)
17231723
self.assertEqual(z.item(), 3)
17241724

1725+
def test_dlpack(self):
1726+
x = mx.array(1, dtype=mx.int32)
1727+
y = np.from_dlpack(x)
1728+
self.assertTrue(mx.array_equal(y, x))
1729+
1730+
x = mx.array([[1.0, 2.0], [3.0, 4.0]])
1731+
y = np.from_dlpack(x)
1732+
self.assertTrue(mx.array_equal(y, x))
1733+
1734+
x = mx.arange(16).reshape(4, 4)
1735+
x = x[::2, ::2]
1736+
y = np.from_dlpack(x)
1737+
self.assertTrue(mx.array_equal(y, x))
1738+
17251739

17261740
if __name__ == "__main__":
17271741
unittest.main()

0 commit comments

Comments
 (0)