@@ -126,7 +126,7 @@ void init_array(nb::module_& m) {
126126 m.attr (" float32" ) = nb::cast (float32);
127127 m.attr (" bfloat16" ) = nb::cast (bfloat16);
128128 m.attr (" complex64" ) = nb::cast (complex64);
129- nb::class_ <Dtype::Category>(
129+ nb::enum_ <Dtype::Category>(
130130 m,
131131 " DtypeCategory" ,
132132 R"pbdoc(
@@ -165,16 +165,16 @@ void init_array(nb::module_& m) {
165165 * :ref:`complex64 <data_types>`
166166
167167 See also :func:`~mlx.core.issubdtype`.
168- )pbdoc" );
169- m. attr (" complexfloating" ) = nb::cast ( complexfloating);
170- m. attr (" floating" ) = nb::cast ( floating);
171- m. attr (" inexact" ) = nb::cast ( inexact);
172- m. attr (" signedinteger" ) = nb::cast ( signedinteger);
173- m. attr (" unsignedinteger" ) = nb::cast ( unsignedinteger);
174- m. attr (" integer" ) = nb::cast ( integer);
175- m. attr (" number" ) = nb::cast ( number);
176- m. attr (" generic" ) = nb::cast ( generic);
177-
168+ )pbdoc" )
169+ . value (" complexfloating" , complexfloating)
170+ . value (" floating" , floating)
171+ . value (" inexact" , inexact)
172+ . value (" signedinteger" , signedinteger)
173+ . value (" unsignedinteger" , unsignedinteger)
174+ . value (" integer" , integer)
175+ . value (" number" , number)
176+ . value (" generic" , generic)
177+ . export_values ();
178178 nb::class_<ArrayAt>(
179179 m,
180180 " _ArrayAt" ,
0 commit comments