Skip to content

Commit eae3fff

Browse files
committed
add eager execution
1 parent a62b45d commit eae3fff

6 files changed

Lines changed: 50 additions & 7 deletions

File tree

python/jittor/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.6.8'
10+
__version__ = '1.1.7.0'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler
@@ -284,6 +284,9 @@ def detach_inplace(x):
284284
return x.swap(x.stop_grad().clone())
285285
Var.start_grad = Var.detach_inplace = detach_inplace
286286

287+
def detach(x):
288+
return x.detach()
289+
287290
def unsqueeze(x, dim):
288291
shape = list(x.shape)
289292
if dim < 0: dim += len(shape) + 1
@@ -623,12 +626,17 @@ def grad(self, grad0, grad1):
623626
624627
'''
625628
def __call__(self, *args):
629+
backup = args
626630
args = list(args)
627631
taped_inputs = []
628632
taped_outputs = []
629633
input_mask = [-1] * len(args)
630634
for i,v in enumerate(args):
631635
if isinstance(v, Var):
636+
if v.is_stop_grad():
637+
# -2 in input_mask represents it is stop_grad
638+
input_mask[i] = -2
639+
continue
632640
v = v.tape()
633641
input_mask[i] = len(taped_inputs)
634642
args[i] = v
@@ -664,7 +672,8 @@ def _grad(self, *args):
664672
for i, r in enumerate(ret):
665673
j = self.input_mask[i]
666674
if j<0:
667-
assert r is None, f"{type(self)}'s {i}-th returned grad should be None, "\
675+
# -2 in input_mask represents it is stop_grad
676+
assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
668677
"because the input value is not jittor variable."
669678
else:
670679
new_ret.append(r)

python/jittor/test/test_clone.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,14 @@ def test(self):
2121
c.stop_grad()
2222
assert jt.number_of_lived_vars()==3
2323

24+
def test2(self):
25+
a = jt.array([1,2])
26+
print(a.detach())
27+
28+
@jt.flag_scope(eager_execution=1)
29+
def test3(self):
30+
a = jt.array([1,2])
31+
print(a.detach())
32+
2433
if __name__ == "__main__":
2534
unittest.main()

python/jittor/test/test_function.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,5 +280,14 @@ def grad(self, grad0, _, grad1):
280280
assert t3 < t2 + 10, (t1,t2,t3)
281281
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
282282

283+
284+
class TestFunctionWithEagerExecution(TestFunction):
285+
@classmethod
286+
def setUpClass(self):
287+
jt.flags.eager_execution = 1
288+
@classmethod
289+
def tearDownClass(self):
290+
jt.flags.eager_execution = 0
291+
283292
if __name__ == "__main__":
284293
unittest.main()

python/jittor/version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
b27082f9444a4e627f7dfc574d0114302ba27b5e
1+
a62b45d6caf9c1c18a9118630ec8a591c576e635

src/fuser.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vec
146146
for (uint i=0; i<ops.size(); i++)
147147
LOGvvvv << ops[i] << dis[i] << deps[i];
148148
}
149-
150149
for (uint i=0; i<vars.size(); i++) {
151150
Var* v = vars[i];
152151
if (!v || v->tflag!=tt) {

src/var_holder.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,41 @@
1515
#include "update_queue.h"
1616

1717
namespace jittor {
18-
18+
19+
DEFINE_FLAG(int, eager_execution, 0, "Use Eager execution rather than lazy execution, This flag makes error message and traceback infomation better.");
20+
1921
list<VarHolder*> VarHolder::hold_vars;
2022

2123
void add_hold_vars(VarHolder* self) {
2224
VarHolder::hold_vars.push_front(self);
2325
self->iter = VarHolder::hold_vars.begin();
26+
if (!eager_execution) return;
27+
auto v = self->var;
28+
for (int i=0; i<5; i++) {
29+
auto op = v->input();
30+
if (!op) break;
31+
if (i==0 && op->name() == string("tape")) return;
32+
if (op->type() == OpType::other) break;
33+
if (op->type() == OpType::reduce) break;
34+
if (op->inputs().size() == 0)
35+
break;
36+
if (op->type() == OpType::broadcast)
37+
return;
38+
v = op->inputs().front();
39+
}
40+
self->sync(true);
2441
}
2542

2643
VarHolder::VarHolder(Var* v) : var(v) {
27-
add_hold_vars(this);
2844
// Var holder has both forward and backward liveness
2945
var->own_both_liveness();
46+
add_hold_vars(this);
3047
}
3148

3249
VarHolder::VarHolder(VarPtr&& v) {
33-
add_hold_vars(this);
3450
var = v.ptr;
3551
v.ptr = nullptr;
52+
add_hold_vars(this);
3653
}
3754

3855
VarHolder::VarHolder(VarHolder* v) : var(v->var) {

0 commit comments

Comments
 (0)