Skip to content

Commit eb12664

Browse files
committed
fix: support exact integers as floats for IntCat
1 parent 23d886f commit eb12664

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

include/bh_python/register_axis.hpp

+22-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,27 @@ auto vectorize_index(T input) {
4747
#define BHP_NOEXCEPT_17
4848
#endif
4949

50+
namespace detail {
51+
template <class T>
52+
decltype(auto) axis_cast(py::handle x) {
53+
return special_cast<T>(x);
54+
}
55+
56+
template <>
57+
inline decltype(auto) axis_cast<int>(py::handle x) {
58+
if(py::isinstance<int>(x))
59+
return py::cast<int>(x);
60+
61+
auto val = py::cast<float>(x);
62+
auto ival = static_cast<int>(val);
63+
64+
if(static_cast<float>(ival) == val)
65+
return ival;
66+
67+
throw py::type_error(py::str("cannot cast {} to int").format(val));
68+
}
69+
} // namespace detail
70+
5071
// we overload vectorize index for category axis
5172
template <class T, class Options>
5273
auto vectorize_index(int (bh::axis::category<T, metadata_t, Options>::*pindex)(const T&)
@@ -56,7 +77,7 @@ auto vectorize_index(int (bh::axis::category<T, metadata_t, Options>::*pindex)(c
5677
auto index = std::mem_fn(pindex);
5778

5879
if(detail::is_value<T>(arg)) {
59-
auto index_value = index(self, detail::special_cast<T>(arg));
80+
auto index_value = index(self, detail::axis_cast<T>(arg));
6081
if(index_value >= self.size())
6182
throw pybind11::key_error(py::str("{!r} not in axis").format(arg));
6283
return py::cast(index_value);

0 commit comments

Comments
 (0)