1+ # ***************************************************************
2+ # Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
3+ # This file is subject to the terms and conditions defined in
4+ # file 'LICENSE.txt', which is part of this source code package.
5+ # ***************************************************************
6+ import unittest
7+ import jittor as jt
8+ import numpy as np
9+ from jittor import Module
10+ from jittor .models import resnet
11+ import pickle
12+
13+ f32 = jt .float32
14+
15+ def matmul (a , b ):
16+ (n , m ), k = a .shape , b .shape [- 1 ]
17+ a = a .broadcast ([n ,m ,k ], dims = [2 ])
18+ b = b .broadcast ([n ,m ,k ], dims = [0 ])
19+ return (a * b ).sum (dim = 1 )
20+
21+
22+ def relu (x ):
23+ return jt .maximum (x , 0.0 )
24+ Relu = jt .make_module (relu )
25+
26+ class Model (Module ):
27+ def __init__ (self , input_size ):
28+ self .linear1 = Linear (input_size , 10 )
29+ self .relu1 = Relu ()
30+ self .linear2 = Linear (10 , 1 )
31+ def execute (self , x ):
32+ x = self .linear1 (x )
33+ x = self .relu1 (x )
34+ return self .linear2 (x )
35+
36+ class Linear (Module ):
37+ def __init__ (self , in_features , out_features , bias = True ):
38+ self .w = (jt .random ((in_features , out_features ))- 0.5 ) / in_features ** 0.5
39+ self .b = jt .random ((out_features ,))- 0.5 if bias else None
40+ def execute (self , x ):
41+ x = matmul (x , self .w )
42+ if self .b is not None :
43+ return x + self .b
44+ return x
45+
46+
47+ class TestTraceVar (unittest .TestCase ):
48+ def test_simple_model (self ):
49+ with jt .flag_scope (trace_py_var = 2 ):
50+
51+ model = Model (input_size = 1 )
52+ batch_size = 10
53+ x = jt .float32 (np .random .rand (batch_size , 1 ))
54+ y = model (x )
55+ y .sync ()
56+
57+
58+ data = jt .dump_trace_data ()
59+ jt .clear_trace_data ()
60+ # with open("/tmp/simple_model.pkl", "wb") as f:
61+ # pickle.dump(data, f)
62+
63+ def test_simple_model_train (self ):
64+ with jt .flag_scope (trace_py_var = 2 ):
65+
66+ model = Model (input_size = 1 )
67+ opt = jt .optim .SGD (model .parameters (), 0.1 )
68+
69+ batch_size = 10
70+ x = jt .float32 (np .random .rand (batch_size , 1 ))
71+ y = model (x )
72+ opt .step (y ** 2 )
73+ jt .sync_all ()
74+
75+ data = jt .dump_trace_data ()
76+ jt .clear_trace_data ()
77+ # with open("/tmp/simple_model_train.pkl", "wb") as f:
78+ # pickle.dump(data, f)
79+
80+ def test_resnet (self ):
81+ with jt .flag_scope (trace_py_var = 2 ):
82+
83+ resnet18 = resnet .Resnet18 ()
84+ x = jt .float32 (np .random .rand (2 , 3 , 224 , 224 ))
85+ y = resnet18 (x )
86+ y .sync ()
87+
88+ data = jt .dump_trace_data ()
89+ jt .clear_trace_data ()
90+ # with open("/tmp/resnet.pkl", "wb") as f:
91+ # pickle.dump(data, f)
92+
93+ def test_resnet_train (self ):
94+ with jt .flag_scope (trace_py_var = 2 ):
95+
96+ resnet18 = resnet .Resnet18 ()
97+ opt = jt .optim .SGD (resnet18 .parameters (), 0.1 )
98+ x = jt .float32 (np .random .rand (2 , 3 , 224 , 224 ))
99+ y = resnet18 (x )
100+
101+ opt .step (y ** 2 )
102+ jt .sync_all ()
103+
104+ data = jt .dump_trace_data ()
105+ jt .clear_trace_data ()
106+ # with open("/tmp/resnet_train.pkl", "wb") as f:
107+ # pickle.dump(data, f)
108+
109+ def test_resnet_train_profile (self ):
110+ with jt .profile_scope (trace_py_var = 1 ):
111+
112+ resnet18 = resnet .Resnet18 ()
113+ opt = jt .optim .SGD (resnet18 .parameters (), 0.1 )
114+ x = jt .float32 (np .random .rand (2 , 3 , 224 , 224 ))
115+ y = resnet18 (x )
116+
117+ opt .step (y ** 2 )
118+ jt .sync_all ()
119+
120+
121+ if __name__ == "__main__" :
122+ unittest .main ()
0 commit comments