@@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
878878 },
879879 " other" _a,
880880 nb::rv_policy::none)
881+ .def (
882+ " __xor__" ,
883+ [](const mx::array& a, const ScalarOrArray v) {
884+ if (!is_comparable_with_array (v)) {
885+ throw_invalid_operation (" bitwise xor" , v);
886+ }
887+ auto b = to_array (v, a.dtype ());
888+ if (mx::issubdtype (a.dtype (), mx::inexact) ||
889+ mx::issubdtype (b.dtype (), mx::inexact)) {
890+ throw std::invalid_argument (
891+ " Floating point types not allowed with bitwise xor." );
892+ }
893+ return mx::bitwise_xor (a, b);
894+ },
895+ " other" _a)
896+ .def (
897+ " __ixor__" ,
898+ [](mx::array& a, const ScalarOrArray v) -> mx::array& {
899+ if (!is_comparable_with_array (v)) {
900+ throw_invalid_operation (" inplace bitwise xor" , v);
901+ }
902+ auto b = to_array (v, a.dtype ());
903+ if (mx::issubdtype (a.dtype (), mx::inexact) ||
904+ mx::issubdtype (b.dtype (), mx::inexact)) {
905+ throw std::invalid_argument (
906+ " Floating point types not allowed bitwise xor." );
907+ }
908+ a.overwrite_descriptor (mx::bitwise_xor (a, b));
909+ return a;
910+ },
911+ " other" _a,
912+ nb::rv_policy::none)
881913 .def (" __int__" , [](mx::array& a) { return nb::int_ (to_scalar (a)); })
882914 .def (" __float__" , [](mx::array& a) { return nb::float_ (to_scalar (a)); })
883915 .def (
0 commit comments