Commit 805f483

better trace_py_var

1 parent 5713130 commit 805f483
18 files changed: 588 additions & 102 deletions
python/jittor/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.1.0'
+__version__ = '1.2.1.1'
 from . import lock
 with lock.lock_scope():
     ori_int = int

python/jittor/nn.py

Lines changed: 2 additions & 2 deletions

@@ -1193,10 +1193,10 @@ def dfs(self, parents, k, callback, callback_leave):
         ret = callback(parents, k, self, n_children)
         if ret == False:
             return
+        parents.append(self)
         for k,v in self.layers.items():
-            parents.append(self)
             v.dfs(parents, k, callback, callback_leave)
-            parents.pop()
+        parents.pop()
         if callback_leave:
             callback_leave(parents, k, self, n_children)
     def append(self, mod):
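Note on the hunk above: `dfs` walks a module tree while `parents` holds the stack of enclosing modules. The change pushes `self` once before iterating children and pops once after, instead of pushing and popping per child, which keeps the stack balanced and cheaper. A minimal sketch of a caller that relies on this stack, modeled on `fill_module_name` in `python/jittor/utils/tracer.py` below (the toy `Sequential` tree and the print callback are illustrative, not part of the commit):

import jittor as jt
from jittor import nn

# toy module tree (illustrative only)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def callback(parents, k, v, n_children):
    # indentation depth equals the number of enclosing modules
    print("  " * len(parents) + f"{k}: {type(v).__name__}")

def callback_leave(parents, k, v, n_children):
    pass

net.dfs([], "net", callback, callback_leave)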
Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import jittor as jt
+import numpy as np
+from jittor import Module
+from jittor.models import resnet
+import pickle
+
+f32 = jt.float32
+
+def matmul(a, b):
+    (n, m), k = a.shape, b.shape[-1]
+    a = a.broadcast([n,m,k], dims=[2])
+    b = b.broadcast([n,m,k], dims=[0])
+    return (a*b).sum(dim=1)
+
+
+def relu(x):
+    return jt.maximum(x, 0.0)
+Relu = jt.make_module(relu)
+
+class Model(Module):
+    def __init__(self, input_size):
+        self.linear1 = Linear(input_size, 10)
+        self.relu1 = Relu()
+        self.linear2 = Linear(10, 1)
+    def execute(self, x):
+        x = self.linear1(x)
+        x = self.relu1(x)
+        return self.linear2(x)
+
+class Linear(Module):
+    def __init__(self, in_features, out_features, bias=True):
+        self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5
+        self.b = jt.random((out_features,))-0.5 if bias else None
+    def execute(self, x):
+        x = matmul(x, self.w)
+        if self.b is not None:
+            return x+self.b
+        return x
+
+
+class TestTraceVar(unittest.TestCase):
+    def test_simple_model(self):
+        with jt.flag_scope(trace_py_var=2):
+
+            model = Model(input_size=1)
+            batch_size = 10
+            x = jt.float32(np.random.rand(batch_size, 1))
+            y = model(x)
+            y.sync()
+
+
+            data = jt.dump_trace_data()
+            jt.clear_trace_data()
+            # with open("/tmp/simple_model.pkl", "wb") as f:
+            #     pickle.dump(data, f)
+
+    def test_simple_model_train(self):
+        with jt.flag_scope(trace_py_var=2):
+
+            model = Model(input_size=1)
+            opt = jt.optim.SGD(model.parameters(), 0.1)
+
+            batch_size = 10
+            x = jt.float32(np.random.rand(batch_size, 1))
+            y = model(x)
+            opt.step(y**2)
+            jt.sync_all()
+
+            data = jt.dump_trace_data()
+            jt.clear_trace_data()
+            # with open("/tmp/simple_model_train.pkl", "wb") as f:
+            #     pickle.dump(data, f)
+
+    def test_resnet(self):
+        with jt.flag_scope(trace_py_var=2):
+
+            resnet18 = resnet.Resnet18()
+            x = jt.float32(np.random.rand(2, 3, 224, 224))
+            y = resnet18(x)
+            y.sync()
+
+            data = jt.dump_trace_data()
+            jt.clear_trace_data()
+            # with open("/tmp/resnet.pkl", "wb") as f:
+            #     pickle.dump(data, f)
+
+    def test_resnet_train(self):
+        with jt.flag_scope(trace_py_var=2):
+
+            resnet18 = resnet.Resnet18()
+            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
+            x = jt.float32(np.random.rand(2, 3, 224, 224))
+            y = resnet18(x)
+
+            opt.step(y**2)
+            jt.sync_all()
+
+            data = jt.dump_trace_data()
+            jt.clear_trace_data()
+            # with open("/tmp/resnet_train.pkl", "wb") as f:
+            #     pickle.dump(data, f)
+
+    def test_resnet_train_profile(self):
+        with jt.profile_scope(trace_py_var=1):
+
+            resnet18 = resnet.Resnet18()
+            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
+            x = jt.float32(np.random.rand(2, 3, 224, 224))
+            y = resnet18(x)
+
+            opt.step(y**2)
+            jt.sync_all()
+
+
+if __name__ == "__main__":
+    unittest.main()
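The commented-out `pickle.dump` lines above suggest the intended workflow: run a model under `trace_py_var`, dump the trace, and save it for offline inspection. A standalone sketch of that round trip (that `data` is picklable is shown by the test comments; its internal structure is not specified by this commit):

import pickle
import numpy as np
import jittor as jt
from jittor.models import resnet

with jt.flag_scope(trace_py_var=2):
    model = resnet.Resnet18()
    y = model(jt.float32(np.random.rand(2, 3, 224, 224)))
    y.sync()                      # trace is recorded at execution time
    data = jt.dump_trace_data()   # collect what was recorded
    jt.clear_trace_data()         # reset before the next traced run

with open("/tmp/resnet_trace.pkl", "wb") as f:
    pickle.dump(data, f)          # load later for analysis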

python/jittor/utils/tracer.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors:
+#     Dun Liang <randonlang@gmail.com>.
+#
+# All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import jittor as jt
+
+def fill_module_name(m, name):
+    ps = []
+    stack = []
+    def callback(parents, k, v, n):
+        stack.append(str(k))
+        for k2, p in v.__dict__.items():
+            if isinstance(p, jt.Var):
+                ps.append(p)
+                p.name(".".join(stack[1:]+[str(k2)]))
+        v._trace_name = str(k)
+    def callback_leave(parents, k, v, n):
+        stack.pop()
+    m.dfs([], name, callback, callback_leave)
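A usage sketch for `fill_module_name`: it walks the module tree via `dfs` and names every `jt.Var` after its attribute path, with the root name dropped by `stack[1:]`. The toy `Net` module and the expected printed name are assumptions for illustration:

import jittor as jt
from jittor import nn
from jittor.utils.tracer import fill_module_name

# tiny illustrative module (hypothetical, not from the commit)
class Net(jt.Module):
    def __init__(self):
        self.fc = nn.Linear(4, 2)
    def execute(self, x):
        return self.fc(x)

net = Net()
fill_module_name(net, "net")
# parameters are now named by attribute path, root name stripped
print(net.fc.weight.name())   # expected: "fc.weight" (an assumption)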

src/executor.cc

Lines changed: 5 additions & 1 deletion

@@ -426,6 +426,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
     #endif
     last_is_cuda = is_cuda;
     op->do_run_after_prepare(jkl);
+    // record trace data
+    if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) {
+        trace_data.record_execution(op, is_fused_op, jkl);
+    }
     LOGvvv << "Finished Op(" >> op->name() << rid >>
         "/" >> queue.size() >> ") output:" << op->outputs();
     if (is_fused_op) {
@@ -458,7 +462,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
         display_memory_info(__FILELINE__, false, true);
         // log jit_key and file location
         op->do_prepare(jkl);
-        string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
+        string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
         LOGe << "[Error] source file location:" << jit_src_path;
         if (is_fused_op) {
             LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"

src/grad.cc

Lines changed: 3 additions & 4 deletions

@@ -84,7 +84,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
     if (grads.size()) {
         grads[0] = make_number(1.f, loss);
         assign_attrs(grads[0].ptr, loss);
-        registe_node_trace_grad(grads[0].ptr, loss, 0);
     }

     vector<pair<Node*, int64>> id_buffer;
@@ -154,6 +153,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
            } else
                douts[i] = nullptr;
        }
+        trace_grad_op = op;
        op->grads(douts, dins);
        // dump "for (Var* in : op->inputs())"
        for (int i=0; i<n_i; i++,j++) {
@@ -175,8 +175,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
            auto out = id_buffer[j].first->var();
            if (id<0) continue;
            Var* dout = grads[id];
+            trace_grad_op = op;
            VarPtr dvar = make_grad(op, out, dout, var, index);
-            registe_node_trace_grad(dvar.ptr, op, index);
            if (dvar && dvar->num>=0 && var->num)
                ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
                    << "dvar" << dvar << "var" << var;
@@ -194,12 +194,12 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
                }
                #endif
                assign_attrs(grad.ptr, var);
-                registe_node_trace_grad(grad.ptr, var, index);
            }
        }
        }
    }
    }
+    trace_grad_op = nullptr;
    // set zero grad
    for (size_t i=0; i<results.size(); i++) {
        Var* var = targets[i];
@@ -211,7 +211,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
            LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
            grad = make_number(0.f, var);
            assign_attrs(grad.ptr, var);
-            registe_node_trace_grad(grad.ptr, var, 0);
        }
    }
    return results;
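In this hunk, a `trace_grad_op` pointer replaces the per-node `registe_node_trace_grad` calls: it is set before each backward op is constructed and cleared when `grad()` finishes, so newly created grad vars can be attributed to their originating forward op. At the Python level this implies backward graphs built inside a `trace_py_var` scope are traced as well; a hedged sketch (whether grad ops actually appear in `data` is inferred from this change, not verified):

import jittor as jt

with jt.flag_scope(trace_py_var=2):
    x = jt.random((3, 3))
    loss = (x * x).sum()
    dx = jt.grad(loss, [x])[0]   # backward ops are built while trace_grad_op is set
    dx.sync()
    data = jt.dump_trace_data()
    jt.clear_trace_data()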

src/misc/ring_buffer.h

Lines changed: 2 additions & 1 deletion

@@ -50,7 +50,8 @@ struct RingBuffer {
     inline ~Cond() {
         // a dirty hack
         // ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
-        cv.__data.__wrefs = 0;
+        // cv.__data.__wrefs = 0;
+        cv.__data = {0};
         pthread_cond_destroy(&cv);
     }

src/node.h

Lines changed: 2 additions & 3 deletions

@@ -120,16 +120,15 @@ struct Node {
     #ifdef NODE_MEMCHECK
     inline Node() {
         lived_nodes[(void*)this] = ++total_node;
-        registe_node_trace(this);
     }

     inline virtual ~Node() {
         lived_nodes.erase((void*)this);
-        unregiste_node_trace(this);
+        if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
     }
     #else
     inline Node() {};
-    inline virtual ~Node() {};
+    inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
     #endif
     inline Var* var() { return (Var*)this; }
     inline Op* op() { return (Op*)this; }

src/op.cc

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ Op::Op() {
     flags.set(NodeFlags::_var, 0);
     flags.set(NodeFlags::_cpu, 1);
     number_of_lived_ops++;
+    if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
 }

 Op::~Op() {

src/parallel_compiler.cc

Lines changed: 2 additions & 2 deletions

@@ -125,9 +125,9 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
         }
         LOGvvv << "Check op needs compile:" << op;
         op->do_prepare(jkl);
-        if (jk.empty()) continue;
+        if (jkl.empty()) continue;

-        const char* jit_key = jk.to_cstring();
+        const char* jit_key = jkl.to_cstring();
         auto iter = jit_key_mapper.find(jit_key);
         if (iter != jit_key_mapper.end()) continue;
