Skip to content

Commit 4c1dfa5

Browse files
authored
xor op on arrays (#1875)
1 parent 5274c3c commit 4c1dfa5

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

python/src/array.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,38 @@ void init_array(nb::module_& m) {
878878
},
879879
"other"_a,
880880
nb::rv_policy::none)
881+
.def(
882+
"__xor__",
883+
[](const mx::array& a, const ScalarOrArray v) {
884+
if (!is_comparable_with_array(v)) {
885+
throw_invalid_operation("bitwise xor", v);
886+
}
887+
auto b = to_array(v, a.dtype());
888+
if (mx::issubdtype(a.dtype(), mx::inexact) ||
889+
mx::issubdtype(b.dtype(), mx::inexact)) {
890+
throw std::invalid_argument(
891+
"Floating point types not allowed with bitwise xor.");
892+
}
893+
return mx::bitwise_xor(a, b);
894+
},
895+
"other"_a)
896+
.def(
897+
"__ixor__",
898+
[](mx::array& a, const ScalarOrArray v) -> mx::array& {
899+
if (!is_comparable_with_array(v)) {
900+
throw_invalid_operation("inplace bitwise xor", v);
901+
}
902+
auto b = to_array(v, a.dtype());
903+
if (mx::issubdtype(a.dtype(), mx::inexact) ||
904+
mx::issubdtype(b.dtype(), mx::inexact)) {
905+
throw std::invalid_argument(
906+
"Floating point types not allowed bitwise xor.");
907+
}
908+
a.overwrite_descriptor(mx::bitwise_xor(a, b));
909+
return a;
910+
},
911+
"other"_a,
912+
nb::rv_policy::none)
881913
.def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); })
882914
.def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); })
883915
.def(

python/tests/test_array.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,6 +1725,7 @@ def test_logical_overloads(self):
17251725
self.assertEqual((mx.array(True) | False).item(), True)
17261726
self.assertEqual((mx.array(False) | False).item(), False)
17271727
self.assertEqual((~mx.array(False)).item(), True)
1728+
self.assertEqual((mx.array(False) ^ True).item(), True)
17281729

17291730
def test_inplace(self):
17301731
iops = [
@@ -1734,6 +1735,7 @@ def test_inplace(self):
17341735
"__ifloordiv__",
17351736
"__imod__",
17361737
"__ipow__",
1738+
"__ixor__",
17371739
]
17381740

17391741
for op in iops:
@@ -1773,6 +1775,10 @@ def test_inplace(self):
17731775
b @= a
17741776
self.assertTrue(mx.array_equal(a, b))
17751777

1778+
a = mx.array(False)
1779+
a ^= True
1780+
self.assertEqual(a.item(), True)
1781+
17761782
def test_inplace_preserves_ids(self):
17771783
a = mx.array([1.0])
17781784
orig_id = id(a)

0 commit comments

Comments
 (0)