Skip to content

Commit b7588fd

Browse files
authored
fix inplace to not make a shallow copy (#804)
1 parent f512b90 commit b7588fd

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

python/src/array.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ void init_array(py::module_& m) {
802802
"other"_a)
803803
.def(
804804
"__iadd__",
805-
[](array& a, const ScalarOrArray v) {
805+
[](array& a, const ScalarOrArray v) -> array& {
806806
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
807807
return a;
808808
},
@@ -821,7 +821,7 @@ void init_array(py::module_& m) {
821821
"other"_a)
822822
.def(
823823
"__isub__",
824-
[](array& a, const ScalarOrArray v) {
824+
[](array& a, const ScalarOrArray v) -> array& {
825825
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
826826
return a;
827827
},
@@ -840,7 +840,7 @@ void init_array(py::module_& m) {
840840
"other"_a)
841841
.def(
842842
"__imul__",
843-
[](array& a, const ScalarOrArray v) {
843+
[](array& a, const ScalarOrArray v) -> array& {
844844
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
845845
return a;
846846
},
@@ -859,7 +859,7 @@ void init_array(py::module_& m) {
859859
"other"_a)
860860
.def(
861861
"__itruediv__",
862-
[](array& a, const ScalarOrArray v) {
862+
[](array& a, const ScalarOrArray v) -> array& {
863863
if (!is_floating_point(a.dtype())) {
864864
throw std::invalid_argument(
865865
"In place division cannot cast to non-floating point type.");
@@ -894,7 +894,7 @@ void init_array(py::module_& m) {
894894
"other"_a)
895895
.def(
896896
"__ifloordiv__",
897-
[](array& a, const ScalarOrArray v) {
897+
[](array& a, const ScalarOrArray v) -> array& {
898898
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
899899
return a;
900900
},
@@ -914,7 +914,7 @@ void init_array(py::module_& m) {
914914
"other"_a)
915915
.def(
916916
"__imod__",
917-
[](array& a, const ScalarOrArray v) {
917+
[](array& a, const ScalarOrArray v) -> array& {
918918
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
919919
return a;
920920
},
@@ -980,7 +980,7 @@ void init_array(py::module_& m) {
980980
"other"_a)
981981
.def(
982982
"__imatmul__",
983-
[](array& a, array& other) {
983+
[](array& a, array& other) -> array& {
984984
a.overwrite_descriptor(matmul(a, other));
985985
return a;
986986
},
@@ -999,7 +999,7 @@ void init_array(py::module_& m) {
999999
"other"_a)
10001000
.def(
10011001
"__ipow__",
1002-
[](array& a, const ScalarOrArray v) {
1002+
[](array& a, const ScalarOrArray v) -> array& {
10031003
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
10041004
return a;
10051005
},
@@ -1034,7 +1034,7 @@ void init_array(py::module_& m) {
10341034
"other"_a)
10351035
.def(
10361036
"__iand__",
1037-
[](array& a, const ScalarOrArray v) {
1037+
[](array& a, const ScalarOrArray v) -> array& {
10381038
auto b = to_array(v, a.dtype());
10391039
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
10401040
throw std::invalid_argument(
@@ -1065,7 +1065,7 @@ void init_array(py::module_& m) {
10651065
"other"_a)
10661066
.def(
10671067
"__ior__",
1068-
[](array& a, const ScalarOrArray v) {
1068+
[](array& a, const ScalarOrArray v) -> array& {
10691069
auto b = to_array(v, a.dtype());
10701070
if (is_floating_point(a.dtype()) || is_floating_point(b.dtype())) {
10711071
throw std::invalid_argument(

python/tests/test_array.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,21 @@ def test_inplace(self):
14421442
b @= a
14431443
self.assertTrue(mx.array_equal(a, b))
14441444

1445+
def test_inplace_preserves_ids(self):
1446+
a = mx.array([1.0])
1447+
orig_id = id(a)
1448+
a += mx.array(2.0)
1449+
self.assertEqual(id(a), orig_id)
1450+
1451+
a[0] = 2.0
1452+
self.assertEqual(id(a), orig_id)
1453+
1454+
a -= mx.array(3.0)
1455+
self.assertEqual(id(a), orig_id)
1456+
1457+
a *= mx.array(3.0)
1458+
self.assertEqual(id(a), orig_id)
1459+
14451460
def test_load_from_pickled_np(self):
14461461
a = np.array([1, 2, 3], dtype=np.int32)
14471462
b = pickle.loads(pickle.dumps(a))

0 commit comments

Comments
 (0)