@@ -16,6 +16,8 @@ limitations under the License.
1616#ifndef ML_DTYPES_INT4_NUMPY_H_
1717#define ML_DTYPES_INT4_NUMPY_H_
1818
19+ #include < type_traits>
20+
1921// Must be included first
2022// clang-format off
2123#include " _src/numpy.h"
@@ -28,6 +30,8 @@ limitations under the License.
2830
2931namespace ml_dtypes {
3032
33+ constexpr char kOutOfRange [] = " out of range value cannot be converted to int4" ;
34+
3135template <typename T>
3236struct Int4TypeDescriptor {
3337 static int Dtype () { return npy_type; }
@@ -114,8 +118,7 @@ bool CastToInt4(PyObject* arg, T* output) {
114118 }
115119 if (d < static_cast <double >(T::lowest ()) ||
116120 d > static_cast <double >(T::highest ())) {
117- PyErr_SetString (PyExc_OverflowError,
118- " out of range value cannot be converted to int4" );
121+ PyErr_SetString (PyExc_OverflowError, kOutOfRange );
119122 }
120123 *output = T (d);
121124 return true ;
@@ -131,9 +134,37 @@ bool CastToInt4(PyObject* arg, T* output) {
131134 if (PyArray_IsScalar (arg, Integer)) {
132135 int64_t v;
133136 PyArray_CastScalarToCtype (arg, &v, PyArray_DescrFromType (NPY_INT64));
137+
138+ if (!(std::numeric_limits<T>::min () <= v &&
139+ v <= std::numeric_limits<T>::max ())) {
140+ PyErr_SetString (PyExc_OverflowError, kOutOfRange );
141+ return false ;
142+ }
134143 *output = T (v);
135144 return true ;
136145 }
146+ if (PyArray_IsScalar (arg, Float)) {
147+ float f;
148+ PyArray_ScalarAsCtype (arg, &f);
149+ if (!(std::numeric_limits<T>::min () <= f &&
150+ f <= std::numeric_limits<T>::max ())) {
151+ PyErr_SetString (PyExc_OverflowError, kOutOfRange );
152+ return false ;
153+ }
154+ *output = T (static_cast <::int8_t >(f));
155+ return true ;
156+ }
157+ if (PyArray_IsScalar (arg, Double)) {
158+ double d;
159+ PyArray_ScalarAsCtype (arg, &d);
160+ if (!(std::numeric_limits<T>::min () <= d &&
161+ d <= std::numeric_limits<T>::max ())) {
162+ PyErr_SetString (PyExc_OverflowError, kOutOfRange );
163+ return false ;
164+ }
165+ *output = T (static_cast <::int8_t >(d));
166+ return true ;
167+ }
137168 return false ;
138169}
139170
@@ -652,7 +683,41 @@ bool RegisterInt4Casts() {
652683 }
653684
654685 // Safe casts from T to other types
655- // TODO(phawkins): add integer types
686+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_INT8,
687+ NPY_NOSCALAR) < 0 ) {
688+ return false ;
689+ }
690+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_INT16,
691+ NPY_NOSCALAR) < 0 ) {
692+ return false ;
693+ }
694+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_INT32,
695+ NPY_NOSCALAR) < 0 ) {
696+ return false ;
697+ }
698+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_INT64,
699+ NPY_NOSCALAR) < 0 ) {
700+ return false ;
701+ }
702+
703+ if (std::is_same_v<uint4, T>) {
704+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_UINT8,
705+ NPY_NOSCALAR) < 0 ) {
706+ return false ;
707+ }
708+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_UINT16,
709+ NPY_NOSCALAR) < 0 ) {
710+ return false ;
711+ }
712+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_UINT32,
713+ NPY_NOSCALAR) < 0 ) {
714+ return false ;
715+ }
716+ if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_UINT64,
717+ NPY_NOSCALAR) < 0 ) {
718+ return false ;
719+ }
720+ }
656721 if (PyArray_RegisterCanCast (&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
657722 NPY_NOSCALAR) < 0 ) {
658723 return false ;
0 commit comments