We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 40cdd27 commit 431ab5cCopy full SHA for 431ab5c
13 files changed
python/jittor/__init__.py
@@ -7,7 +7,7 @@
7
# This file is subject to the terms and conditions defined in
8
# file 'LICENSE.txt', which is part of this source code package.
9
# ***************************************************************
10
-__version__ = '1.1.6.5'
+__version__ = '1.1.6.6'
11
from . import lock
12
with lock.lock_scope():
13
from . import compiler
@@ -196,6 +196,7 @@ def clean():
196
gc.collect()
197
198
cast = unary
199
+Var.cast = Var.cast
200
201
def array(data, dtype=None):
202
if isinstance(data, core.Var):
@@ -250,15 +251,15 @@ def norm(x, k, dim):
250
251
252
origin_reshape = reshape
253
def reshape(x, *shape):
- if len(shape) == 1 and isinstance(shape[0], Sequence):
254
+ if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)):
255
shape = shape[0]
256
return origin_reshape(x, shape)
257
reshape.__doc__ = origin_reshape.__doc__
258
Var.view = Var.reshape = view = reshape
259
260
origin_transpose = transpose
261
def transpose(x, *dim):
- if len(dim) == 1 and isinstance(dim[0], Sequence):
262
+ if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
263
dim = dim[0]
264
return origin_transpose(x, dim)
265
transpose.__doc__ = origin_transpose.__doc__
python/jittor/contrib.py
@@ -147,6 +147,7 @@ def setitem(x, slices, value):
147
reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
148
xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
149
value = jt.broadcast(value, xslice)
150
+ value = value.cast(x.dtype)
151
one = jt.broadcast(1, xslice)
152
if not isinstance(reindex_args[0][0], jt.Var):
153
reindex_args = (x.shape,) + reindex_args[1:]
python/jittor/test/test_binary_op.py
@@ -33,8 +33,8 @@ def test_binary_op(self):
33
assert (x == 8).all()
34
x = (jt.array(2) ** jt.array(3)).data
35
36
- a = [1,2,3]
37
- b = [7,10,13]
+ a = np.array([1,2,3])
+ b = np.array([7,10,13])
38
check("logical_and", a, b)
39
check("logical_or", a, b)
40
check("logical_xor", a, b)
@@ -79,6 +79,8 @@ def check(op, a, b):
79
80
def test_r(self):
81
def check(op, a, b):
82
+ a = np.array(a)
83
+ b = np.array(b)
84
if jt.flags.use_cuda and op == "@":
85
return
86
jb = jt.array(b)
python/jittor/test/test_concat_op.py
@@ -15,9 +15,9 @@ def check(tmp, dim=0):
15
res2 = jt.contrib.concat(tmp, dim=dim)
16
assert (res1!=res2).data.sum()==0, "concat fail..."
17
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
18
- check([jt.array(range(24)).reshape((1,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
19
- check([jt.array(range(120)).reshape((5,2,3,4)), jt.array(range(24)).reshape((1,2,3,4))])
20
- check([jt.array(range(5)).reshape((5,1)), jt.array(range(1)).reshape((1,1))])
+ check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
+ check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
+ check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
21
print('concat success...')
22
23
if __name__ == "__main__":
python/jittor/test/test_nano_string.py
@@ -50,7 +50,7 @@ def test_type(self):
50
assert str(jt.NanoString(jt.float32)) == "float32"
51
assert str(jt.NanoString(jt.float64)) == "float64"
52
assert str(jt.NanoString(jt.int8)) == "int8"
53
- assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
+ assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32"
54
assert str(jt.NanoString(jt.sum)) == "add"
55
56
def get_error_str(call):
python/jittor/test/test_node.py
@@ -53,9 +53,9 @@ def test_lived(self):
da, db = jt.grad(c, [a, b])
da.name('da')
db.name('db')
- check(5,7,5) # dc, 3, da, 1, db, 1
+ check(5,6,4) # dc, 3, da, 1, db, 1
57
del a, b, c
58
- check(2,6,4)
+ check(2,5,3)
59
da.sync(), db.sync()
60
check(2,2,0)
61
del da, db
python/jittor/test/test_reshape.py
@@ -70,6 +70,10 @@ def test_flatten(self):
70
assert a.flatten(1).shape == [2,12]
71
assert a.flatten(0,-2).shape == [6,4]
72
73
+ def test_reshape_var(self):
74
+ a = jt.zeros(10)
75
+ b = a.reshape(a.shape)
76
+
77
78
unittest.main()
python/jittor/test/test_slice.py
@@ -0,0 +1,24 @@
1
+# ***************************************************************
2
+# Copyright (c) 2020 Jittor.
3
+# Authors:
4
+# Dun Liang <randonlang@gmail.com>.
5
+# All Rights Reserved.
6
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+import unittest
+import jittor as jt
+import numpy as np
14
+class TestSlice(unittest.TestCase):
+ def test_slice_bool(self):
+ a = jt.zeros(10, "bool")
+ a[1] = True
+ a[2] = 1
+ assert a.dtype == "bool"
+ print(a)
+if __name__ == "__main__":
24
+ unittest.main()
python/jittor/test/test_unary_op.py
@@ -25,12 +25,12 @@ def test_unary_op(self):
25
assert jt.float64(1).data.dtype == "float64"
26
assert (jt.abs(-1) == 1).data.all()
27
assert (abs(-jt.float64(1)) == 1).data.all()
28
- a = [-1,2,3,0]
+ a = np.array([-1,2,3,0])
29
check("abs", a)
30
check("negative", a)
31
check("logical_not", a)
32
check("bitwise_not", a)
- b = [1.1, 2.2, 3.3, 4.4, -1, 0]
+ b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0])
check("log", a)
check("exp", a)
check("sqrt", a)
@@ -42,7 +42,7 @@ def test_grad(self):
42
"cos", "arccos", "cosh", "arccosh",
43
"sigmoid",
44
]
45
- a = [1.1, 2.2, 3.3, 4.4]
+ a = np.array([1.1, 2.2, 3.3, 4.4])
46
for op in ops:
47
if op == "abs":
48
b = np.array(a+[-1,])
src/misc/nano_string.h
@@ -156,7 +156,7 @@ NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
156
if (dsize==8) return ns_int64;
157
if (dsize==4) return ns_int32;
158
if (dsize==2) return ns_int16;
159
- return ns_int8;
+ return v1;
160
}
161
162
0 commit comments