Skip to content

Commit 05a6732

Browse files
committed
fix vary shape setitem
1 parent d31cdab commit 05a6732

3 files changed

Lines changed: 16 additions & 1 deletion

File tree

python/jittor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.2.0.5'
10+
__version__ = '1.2.0.6'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/test/test_slice.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def numpy_grad(vars):
133133
assert np.allclose(da.numpy(), nda, atol = 1e-3)
134134
assert np.allclose(db.numpy(), ndb, atol = 1e-3)
135135

136+
def test_vary_shape_setitem(self):
137+
a = jt.array([1,2,3,4,5])
138+
b = jt.array([1,2,3,4,5])
139+
c = jt.where(b>3)
140+
a[c] = 0
141+
assert (a.data == [1,2,3,0,0]).all()
142+
136143

137144

138145
if __name__ == "__main__":

src/ops/setitem_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
176176
}
177177

178178
void SetitemOp::jit_prepare() {
179+
for (int i=0; i<o_shape.size(); i++)
180+
if (o_shape[i]<0) {
181+
// because output shape is inferd, check in
182+
// executor not work
183+
// reinfer shape if o_shape has vary shape
184+
infer_shape();
185+
break;
186+
}
179187
auto data = input(1);
180188
add_jit_define("OP", op);
181189
add_jit_define("Td", data->dtype());

0 commit comments

Comments
 (0)