Skip to content

Commit 431ab5c

Browse files
committed
fix bool setitem and reshape NanoVector
1 parent 40cdd27 commit 431ab5c

13 files changed

Lines changed: 57 additions & 15 deletions

python/jittor/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.6.5'
10+
__version__ = '1.1.6.6'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler
@@ -196,6 +196,7 @@ def clean():
196196
gc.collect()
197197

198198
cast = unary
199+
Var.cast = Var.cast
199200

200201
def array(data, dtype=None):
201202
if isinstance(data, core.Var):
@@ -250,15 +251,15 @@ def norm(x, k, dim):
250251

251252
origin_reshape = reshape
252253
def reshape(x, *shape):
253-
if len(shape) == 1 and isinstance(shape[0], Sequence):
254+
if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)):
254255
shape = shape[0]
255256
return origin_reshape(x, shape)
256257
reshape.__doc__ = origin_reshape.__doc__
257258
Var.view = Var.reshape = view = reshape
258259

259260
origin_transpose = transpose
260261
def transpose(x, *dim):
261-
if len(dim) == 1 and isinstance(dim[0], Sequence):
262+
if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
262263
dim = dim[0]
263264
return origin_transpose(x, dim)
264265
transpose.__doc__ = origin_transpose.__doc__

python/jittor/contrib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def setitem(x, slices, value):
147147
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
148148
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
149149
value = jt.broadcast(value, xslice)
150+
value = value.cast(x.dtype)
150151
one = jt.broadcast(1, xslice)
151152
if not isinstance(reindex_args[0][0], jt.Var):
152153
reindex_args = (x.shape,) + reindex_args[1:]

python/jittor/test/test_binary_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def test_binary_op(self):
3333
assert (x == 8).all()
3434
x = (jt.array(2) ** jt.array(3)).data
3535
assert (x == 8).all()
36-
a = [1,2,3]
37-
b = [7,10,13]
36+
a = np.array([1,2,3])
37+
b = np.array([7,10,13])
3838
check("logical_and", a, b)
3939
check("logical_or", a, b)
4040
check("logical_xor", a, b)
@@ -79,6 +79,8 @@ def check(op, a, b):
7979

8080
def test_r(self):
8181
def check(op, a, b):
82+
a = np.array(a)
83+
b = np.array(b)
8284
if jt.flags.use_cuda and op == "@":
8385
return
8486
jb = jt.array(b)

python/jittor/test/test_concat_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def check(tmp, dim=0):
1515
res2 = jt.contrib.concat(tmp, dim=dim)
1616
assert (res1!=res2).data.sum()==0, "concat fail..."
1717
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
18-
check([jt.array(range(24)).reshape((1,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
19-
check([jt.array(range(120)).reshape((5,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
20-
check([jt.array(range(5)).reshape((5,1)), jt.array(range(1)).reshape((1,1))])
18+
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
19+
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
20+
check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
2121
print('concat success...')
2222

2323
if __name__ == "__main__":

python/jittor/test/test_nano_string.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_type(self):
5050
assert str(jt.NanoString(jt.float32)) == "float32"
5151
assert str(jt.NanoString(jt.float64)) == "float64"
5252
assert str(jt.NanoString(jt.int8)) == "int8"
53-
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
53+
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32"
5454
assert str(jt.NanoString(jt.sum)) == "add"
5555

5656
def get_error_str(call):

python/jittor/test/test_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def test_lived(self):
5353
da, db = jt.grad(c, [a, b])
5454
da.name('da')
5555
db.name('db')
56-
check(5,7,5) # dc, 3, da, 1, db, 1
56+
check(5,6,4) # dc, 3, da, 1, db, 1
5757
del a, b, c
58-
check(2,6,4)
58+
check(2,5,3)
5959
da.sync(), db.sync()
6060
check(2,2,0)
6161
del da, db

python/jittor/test/test_reshape.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def test_flatten(self):
7070
assert a.flatten(1).shape == [2,12]
7171
assert a.flatten(0,-2).shape == [6,4]
7272

73+
def test_reshape_var(self):
74+
a = jt.zeros(10)
75+
b = a.reshape(a.shape)
76+
7377

7478
if __name__ == "__main__":
7579
unittest.main()

python/jittor/test/test_slice.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# ***************************************************************
2+
# Copyright (c) 2020 Jittor.
3+
# Authors:
4+
# Dun Liang <randonlang@gmail.com>.
5+
# All Rights Reserved.
6+
# This file is subject to the terms and conditions defined in
7+
# file 'LICENSE.txt', which is part of this source code package.
8+
# ***************************************************************
9+
import unittest
10+
import jittor as jt
11+
import numpy as np
12+
13+
14+
class TestSlice(unittest.TestCase):
15+
def test_slice_bool(self):
16+
a = jt.zeros(10, "bool")
17+
a[1] = True
18+
a[2] = 1
19+
assert a.dtype == "bool"
20+
print(a)
21+
22+
23+
if __name__ == "__main__":
24+
unittest.main()

python/jittor/test/test_unary_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ def test_unary_op(self):
2525
assert jt.float64(1).data.dtype == "float64"
2626
assert (jt.abs(-1) == 1).data.all()
2727
assert (abs(-jt.float64(1)) == 1).data.all()
28-
a = [-1,2,3,0]
28+
a = np.array([-1,2,3,0])
2929
check("abs", a)
3030
check("negative", a)
3131
check("logical_not", a)
3232
check("bitwise_not", a)
33-
b = [1.1, 2.2, 3.3, 4.4, -1, 0]
33+
b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
3434
check("log", a)
3535
check("exp", a)
3636
check("sqrt", a)
@@ -42,7 +42,7 @@ def test_grad(self):
4242
"cos", "arccos", "cosh", "arccosh",
4343
"sigmoid",
4444
]
45-
a = [1.1, 2.2, 3.3, 4.4]
45+
a = np.array([1.1, 2.2, 3.3, 4.4])
4646
for op in ops:
4747
if op == "abs":
4848
b = np.array(a+[-1,])

src/misc/nano_string.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
156156
if (dsize==8) return ns_int64;
157157
if (dsize==4) return ns_int32;
158158
if (dsize==2) return ns_int16;
159-
return ns_int8;
159+
return v1;
160160
}
161161
}
162162

0 commit comments

Comments
 (0)