Skip to content

Commit bbeedd4

Browse files
ChromeHeartsThe ml_dtypes Authors
authored andcommitted
Added support for int4 casting to wider integers such as int8
Addes support to cast np.float32 and np.float64 into int4 PiperOrigin-RevId: 567667633
1 parent fc69958 commit bbeedd4

File tree

4 files changed

+77
-6
lines changed

4 files changed

+77
-6
lines changed

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
## [0.3.1] - 2023-09-22
27+
28+
* Added support for int4 casting to wider integers such as int8
29+
* Addes support to cast np.float32 and np.float64 into int4
30+
2631
## [0.3.0] - 2023-09-19
2732

2833
* Dropped support for Python 3.8, following [NEP 29].
@@ -44,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
4449

4550
* Initial release
4651

47-
[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...HEAD
52+
[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.1...HEAD
53+
[0.3.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...v0.3.1
4854
[0.3.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.2.0...v0.3.0
4955
[0.2.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.1.0...v0.2.0
5056
[0.1.0]: https://github.com/jax-ml/ml_dtypes/releases/tag/v0.1.0

ml_dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = '0.3.0' # Keep in sync with pyproject.toml:version
15+
__version__ = '0.3.1' # Keep in sync with pyproject.toml:version
1616
__all__ = [
1717
'__version__',
1818
'bfloat16',

ml_dtypes/_src/int4_numpy.h

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
#ifndef ML_DTYPES_INT4_NUMPY_H_
1717
#define ML_DTYPES_INT4_NUMPY_H_
1818

19+
#include <type_traits>
20+
1921
// Must be included first
2022
// clang-format off
2123
#include "_src/numpy.h"
@@ -28,6 +30,8 @@ limitations under the License.
2830

2931
namespace ml_dtypes {
3032

33+
constexpr char kOutOfRange[] = "out of range value cannot be converted to int4";
34+
3135
template <typename T>
3236
struct Int4TypeDescriptor {
3337
static int Dtype() { return npy_type; }
@@ -114,8 +118,7 @@ bool CastToInt4(PyObject* arg, T* output) {
114118
}
115119
if (d < static_cast<double>(T::lowest()) ||
116120
d > static_cast<double>(T::highest())) {
117-
PyErr_SetString(PyExc_OverflowError,
118-
"out of range value cannot be converted to int4");
121+
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
119122
}
120123
*output = T(d);
121124
return true;
@@ -131,9 +134,37 @@ bool CastToInt4(PyObject* arg, T* output) {
131134
if (PyArray_IsScalar(arg, Integer)) {
132135
int64_t v;
133136
PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64));
137+
138+
if (!(std::numeric_limits<T>::min() <= v &&
139+
v <= std::numeric_limits<T>::max())) {
140+
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
141+
return false;
142+
}
134143
*output = T(v);
135144
return true;
136145
}
146+
if (PyArray_IsScalar(arg, Float)) {
147+
float f;
148+
PyArray_ScalarAsCtype(arg, &f);
149+
if (!(std::numeric_limits<T>::min() <= f &&
150+
f <= std::numeric_limits<T>::max())) {
151+
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
152+
return false;
153+
}
154+
*output = T(static_cast<::int8_t>(f));
155+
return true;
156+
}
157+
if (PyArray_IsScalar(arg, Double)) {
158+
double d;
159+
PyArray_ScalarAsCtype(arg, &d);
160+
if (!(std::numeric_limits<T>::min() <= d &&
161+
d <= std::numeric_limits<T>::max())) {
162+
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
163+
return false;
164+
}
165+
*output = T(static_cast<::int8_t>(d));
166+
return true;
167+
}
137168
return false;
138169
}
139170

@@ -652,7 +683,41 @@ bool RegisterInt4Casts() {
652683
}
653684

654685
// Safe casts from T to other types
655-
// TODO(phawkins): add integer types
686+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT8,
687+
NPY_NOSCALAR) < 0) {
688+
return false;
689+
}
690+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT16,
691+
NPY_NOSCALAR) < 0) {
692+
return false;
693+
}
694+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT32,
695+
NPY_NOSCALAR) < 0) {
696+
return false;
697+
}
698+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT64,
699+
NPY_NOSCALAR) < 0) {
700+
return false;
701+
}
702+
703+
if (std::is_same_v<uint4, T>) {
704+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT8,
705+
NPY_NOSCALAR) < 0) {
706+
return false;
707+
}
708+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT16,
709+
NPY_NOSCALAR) < 0) {
710+
return false;
711+
}
712+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT32,
713+
NPY_NOSCALAR) < 0) {
714+
return false;
715+
}
716+
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT64,
717+
NPY_NOSCALAR) < 0) {
718+
return false;
719+
}
720+
}
656721
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
657722
NPY_NOSCALAR) < 0) {
658723
return false;

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ml_dtypes"
3-
version = "0.3.0" # Keep in sync with ml_dtypes/__init__.py:__version__
3+
version = "0.3.1" # Keep in sync with ml_dtypes/__init__.py:__version__
44
description = ""
55
readme = "README.md"
66
requires-python = ">=3.9"

0 commit comments

Comments
 (0)