Skip to content

Commit f1a40bc

Browse files
committed
improve import speed
1 parent 5ce1a1d commit f1a40bc

5 files changed

Lines changed: 79 additions & 13 deletions

File tree

python/jittor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.7.11'
10+
__version__ = '1.1.7.12'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/compiler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -784,11 +784,16 @@ def try_find_exe(*args):
784784
def check_pybt(gdb_path, python_path):
785785
if gdb_path=='' or python_path=='':
786786
return False
787-
ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'")
788-
if 'python frame' in ret:
789-
LOG.v("py-bt found in gdb.")
790-
return True
791-
return False
787+
return True
788+
# TODO: prev we use below code to check has py-bt or nor
789+
# but it is too slow, so we comment it,
790+
# find a better way to check py-bt exist
791+
792+
# ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'")
793+
# if 'python frame' in ret:
794+
# LOG.v("py-bt found in gdb.")
795+
# return True
796+
# return False
792797

793798
def check_debug_flags():
794799
global is_debug

python/jittor/pyjt_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def find_bc(i):
449449
continue
450450
else:
451451
defs.append(def_info)
452-
LOG.vvv(json.dumps(def_info, indent=4))
452+
LOG.vvv(lambda: json.dumps(def_info, indent=4))
453453
# deal with defs
454454
if len(defs) == 0: return
455455
# include_name = h[4:] # remove "src/" prefix

python/jittor/test/test_init.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# ***************************************************************
2+
# Copyright (c) Jittor 2020, Author:
3+
# All Rights Reserved.
4+
# This file is subject to the terms and conditions defined in
5+
# file 'LICENSE.txt', which is part of this source code package.
6+
# ***************************************************************
7+
import jittor as jt
8+
import unittest
9+
import numpy as np
10+
from jittor import models
11+
12+
pass_this_test = False
13+
try:
14+
jt.dirty_fix_pytorch_runtime_error()
15+
import torch
16+
import torchvision
17+
except Exception as e:
18+
pass_this_test = True
19+
20+
def get_error(a, b):
21+
return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5) , np.abs(a-b)
22+
23+
def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5, mean_atol=1e-5):
24+
pa = [ p for p in jt_mod.parameters() if not p.is_stop_grad() ]
25+
pb = list(torch_mod.parameters())
26+
assert len(pa) == len(pb)
27+
error_count = 0
28+
for a,b in zip(pa, pb):
29+
assert a.shape == list(b.shape), (a.shape, b.shape, a.name())
30+
stda, meana = np.std(a.numpy()), np.mean(a.numpy())
31+
stdb, meanb = np.std(b.detach().numpy()), np.mean(b.detach().numpy())
32+
33+
r_err, a_err = get_error(stda, stdb)
34+
if r_err > rtol and a_err > atol:
35+
error_count += 1
36+
print("compare std error", stda, stdb, r_err, a_err, a.name(), a.shape)
37+
38+
r_err, a_err = get_error(meana, meanb)
39+
if r_err > rtol and a_err > mean_atol:
40+
error_count += 1
41+
print("compare mean error", meana, meanb, r_err, a_err, a.name(), a.shape)
42+
assert error_count == 0
43+
44+
@unittest.skipIf(pass_this_test, f"pass init check, no torch found")
45+
class TestInit(unittest.TestCase):
46+
@classmethod
47+
def setUpClass(self):
48+
jt.seed(0)
49+
np.random.seed(0)
50+
torch.manual_seed(0)
51+
52+
def test_conv(self):
53+
check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-3)
54+
55+
def test_resnet(self):
56+
check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2, mean_atol=1e-2)
57+
58+
if __name__ == "__main__":
59+
unittest.main()

python/jittor_utils/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@ def log_capture_read(self):
3131
return cc.log_capture_read()
3232

3333
def _log(self, level, verbose, *msg):
34-
if len(msg):
35-
msg = " ".join([ str(m) for m in msg ])
36-
else:
37-
msg = str(msg)
34+
if self.log_silent or verbose > self.log_v:
35+
return
36+
ss = ""
37+
for m in msg:
38+
if callable(m):
39+
m = m()
40+
ss += str(m)
41+
msg = ss
3842
f = inspect.currentframe()
3943
fileline = inspect.getframeinfo(f.f_back.f_back)
4044
fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
4145
if cc and hasattr(cc, "log"):
4246
cc.log(fileline, level, verbose, msg)
4347
else:
44-
if self.log_silent or verbose > self.log_v:
45-
return
4648
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
4749
tid = threading.get_ident()%100
4850
v = f" v{verbose}" if verbose else ""

0 commit comments

Comments
 (0)