@@ -100,8 +100,8 @@ array nd_array_to_mlx(
100100 }
101101}
102102
103- template <typename Lib , typename T >
104- nb::ndarray<Lib> mlx_to_nd_array (
103+ template <typename T , typename ... NDParams >
104+ nb::ndarray<NDParams...> mlx_to_nd_array_impl (
105105 array a,
106106 std::optional<nb::dlpack::dtype> t = {}) {
107107 {
@@ -110,47 +110,51 @@ nb::ndarray<Lib> mlx_to_nd_array(
110110 }
111111 std::vector<size_t > shape (a.shape ().begin (), a.shape ().end ());
112112 std::vector<int64_t > strides (a.strides ().begin (), a.strides ().end ());
113- return nb::ndarray<Lib >(
113+ return nb::ndarray<NDParams... >(
114114 a.data <T>(),
115115 a.ndim (),
116116 shape.data (),
117- nb::handle (),
117+ nb::none (),
118118 strides.data (),
119119 t.value_or (nb::dtype<T>()));
120120}
121121
122- template <typename Lib >
123- nb::ndarray<Lib > mlx_to_nd_array (const array& a) {
122+ template <typename ... NDParams >
123+ nb::ndarray<NDParams... > mlx_to_nd_array (const array& a) {
124124 switch (a.dtype ()) {
125125 case bool_:
126- return mlx_to_nd_array<Lib, bool >(a);
126+ return mlx_to_nd_array_impl< bool , NDParams... >(a);
127127 case uint8:
128- return mlx_to_nd_array<Lib, uint8_t >(a);
128+ return mlx_to_nd_array_impl< uint8_t , NDParams... >(a);
129129 case uint16:
130- return mlx_to_nd_array<Lib, uint16_t >(a);
130+ return mlx_to_nd_array_impl< uint16_t , NDParams... >(a);
131131 case uint32:
132- return mlx_to_nd_array<Lib, uint32_t >(a);
132+ return mlx_to_nd_array_impl< uint32_t , NDParams... >(a);
133133 case uint64:
134- return mlx_to_nd_array<Lib, uint64_t >(a);
134+ return mlx_to_nd_array_impl< uint64_t , NDParams... >(a);
135135 case int8:
136- return mlx_to_nd_array<Lib, int8_t >(a);
136+ return mlx_to_nd_array_impl< int8_t , NDParams... >(a);
137137 case int16:
138- return mlx_to_nd_array<Lib, int16_t >(a);
138+ return mlx_to_nd_array_impl< int16_t , NDParams... >(a);
139139 case int32:
140- return mlx_to_nd_array<Lib, int32_t >(a);
140+ return mlx_to_nd_array_impl< int32_t , NDParams... >(a);
141141 case int64:
142- return mlx_to_nd_array<Lib, int64_t >(a);
142+ return mlx_to_nd_array_impl< int64_t , NDParams... >(a);
143143 case float16:
144- return mlx_to_nd_array<Lib, float16_t >(a);
144+ return mlx_to_nd_array_impl< float16_t , NDParams... >(a);
145145 case bfloat16:
146- return mlx_to_nd_array<Lib, bfloat16_t >(a, nb::bfloat16);
146+ return mlx_to_nd_array_impl< bfloat16_t , NDParams... >(a, nb::bfloat16);
147147 case float32:
148- return mlx_to_nd_array<Lib, float >(a);
148+ return mlx_to_nd_array_impl< float , NDParams... >(a);
149149 case complex64:
150- return mlx_to_nd_array<Lib, std::complex <float >>(a);
150+ return mlx_to_nd_array_impl< std::complex <float >, NDParams... >(a);
151151 }
152152}
153153
154154nb::ndarray<nb::numpy> mlx_to_np_array (const array& a) {
155155 return mlx_to_nd_array<nb::numpy>(a);
156156}
157+
158+ nb::ndarray<> mlx_to_dlpack (const array& a) {
159+ return mlx_to_nd_array<>(a);
160+ }
0 commit comments