Skip to content

Commit a62b45d

Browse files
committed
add negtive dim support
1 parent d833891 commit a62b45d

5 files changed

Lines changed: 36 additions & 4 deletions

File tree

python/jittor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.6.7'
10+
__version__ = '1.1.6.8'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/test/test_broadcast_to_op.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,19 @@ def setUp(self):
120120
def tearDown(self):
121121
jt.flags.use_cuda = 0
122122

123+
124+
class TestBroadcastToOpMisc(unittest.TestCase):
125+
def test_negtive_dim(self):
126+
a = jt.array([1,2])
127+
assert (a.broadcast([2,2], [-1]).data == [[1,1],[2,2]]).all()
128+
assert (a.broadcast([2,2], [-2]).data == [[1,2],[1,2]]).all()
129+
130+
def test_negtive_dim2(self):
131+
a = jt.array([1,2])
132+
b = jt.zeros((2,2))
133+
assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all()
134+
assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all()
135+
136+
123137
if __name__ == "__main__":
124138
unittest.main()

python/jittor/test/test_reduce_op.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,12 @@ def setUp(self):
7070
def tearDown(self):
7171
jt.flags.use_cuda = 0
7272

73+
74+
class TestReduceOpMisc(unittest.TestCase):
75+
def test_negtive_dim(self):
76+
a = jt.array([[1,2],[3,4]])
77+
assert (a.sum(-1).data == [3,7]).all()
78+
assert (a.sum(-2).data == [4,6]).all()
79+
7380
if __name__ == "__main__":
7481
unittest.main()

src/ops/broadcast_to_op.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
2727
z = create_output(NanoVector(), x->dtype());
2828
bcast_mask = 0;
2929
keepdims = 0;
30+
auto ydim = y->shape.size();
3031
if (dims.size()) {
31-
for (auto a : dims) bcast_mask |= 1 << a;
32+
for (auto dim : dims) {
33+
if (dim<0) dim += ydim;
34+
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
35+
bcast_mask |= 1 << dim;
36+
}
3237
} else
3338
keepdims = 1;
3439
}
@@ -62,8 +67,13 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
6267
z = create_output(nullptr, x->dtype());
6368
bcast_mask = 0;
6469
keepdims = 0;
70+
auto ydim = shape.size();
6571
if (dims.size()) {
66-
for (auto a : dims) bcast_mask |= 1 << a;
72+
for (auto dim : dims) {
73+
if (dim<0) dim += ydim;
74+
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
75+
bcast_mask |= 1 << dim;
76+
}
6777
} else
6878
keepdims = 1;
6979
}

src/ops/reduce_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
5757
} else {
5858
reduce_mask = 0;
5959
for (auto dim : dims) {
60-
CHECKop(dim,<,xdim) << "Wrong dims number:" << dims;
60+
if (dim<0) dim += xdim;
61+
CHECK(dim>=0 && dim<xdim) << "Wrong dims number:" << dims;
6162
reduce_mask |= 1<<dim;
6263
}
6364
}

0 commit comments

Comments
 (0)