Skip to content

Commit 7e1678c

Browse files
committed
fix zero dim broadcast
1 parent f75c26a commit 7e1678c

4 files changed

Lines changed: 11 additions & 6 deletions

File tree

python/jittor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.7.9'
10+
__version__ = '1.1.7.10'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/test/test_broadcast_to_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,11 @@ def test_negtive_dim2(self):
133133
assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all()
134134
assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all()
135135

136+
def test_zero_dim(self):
137+
a = jt.array(1.0)
138+
b = a.broadcast([0])
139+
assert b.shape == [0]
140+
136141

137142
if __name__ == "__main__":
138143
unittest.main()

src/ops/broadcast_to_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
6363
set_type(OpType::broadcast);
6464
CHECKop(shape.size(),>,0u) << "Number of shape should greater than 0.";
6565
for (auto v : shape)
66-
CHECKop(v,>,0u) << "Shape should greater than 0.";
66+
CHECKop(v,>=,0u) << "Shape should greater than 0.";
6767
z = create_output(nullptr, x->dtype());
6868
bcast_mask = 0;
6969
keepdims_mask = 0;
@@ -78,7 +78,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
7878
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
7979
if (x->shape.size() < shape.size()) return true;
8080
for (uint i=shape.size()-1, j=x->shape.size()-1; i<shape.size(); i--,j--)
81-
if (x->shape[j]< 0 || x->shape[j] < shape[i]) return true;
81+
if (x->shape[j]< 0 || (x->shape[j] != shape[i] && shape[i] != 1)) return true;
8282
return false;
8383
}
8484

src/ops/reindex_reduce_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ void ReindexReduceOp::infer_shape() {
5757
CHECKop(shape.size(),==,indexes.size()) << "Number of shape and indexes should be the same.";
5858
CHECK(shape.size()) << "Number of shape should greater than 0.";
5959
for (auto v : shape)
60-
CHECKop(v,>,0u) << "Shape should greater than 0.";
60+
CHECKop(v,>=,0u) << "Shape should greater than 0.";
6161
x->set_shape(shape);
62-
CHECKop(x->size,>,0u);
63-
CHECKop(y->size,>,0u);
62+
CHECKop(x->size,>=,0u);
63+
CHECKop(y->size,>=,0u);
6464
}
6565

6666
void ReindexReduceOp::jit_prepare() {

0 commit comments

Comments
 (0)