
Commit 7b2f154

Merge pull request #320 from ztqakita/master
Add ResNet example
2 parents: 7b40366 + c00bec5

File tree

2 files changed (+330, -0 lines)


brainpy/_src/dyn/layers/linear.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 __all__ = [
   'Dense',
+  'Identity',
 ]
 
 
@@ -185,3 +186,13 @@ def offline_fit(self,
     self.W.value = Wff
     self.b.value = bias[0]
 
+
+class Identity(Layer):
+  r"""A placeholder identity operator that is argument-insensitive.
+  """
+
+  def __init__(self, *args, **kwargs) -> None:
+    super(Identity, self).__init__(*args, **kwargs)
+
+  def update(self, sha, x):
+    return x
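The new layer simply echoes its input, so it can stand in for the empty bp.Sequential() used as the default shortcut in the ResNet example below. A minimal sketch of its behavior, assuming the (shared-args, input) calling convention used throughout that example:

import brainpy as bp
import brainpy.math as bm

layer = bp.layers.Identity()
x = bm.random.randn(4, 8)
y = layer({'fit': False}, x)   # the shared-argument dict is accepted and ignored
print((y == x).all())          # True: the input passes through unchanged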
Lines changed: 319 additions & 0 deletions
@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-

import argparse
import os
import sys
import time
from functools import partial

import brainpy_datasets as bd

import brainpy as bp
import brainpy.math as bm

bm.set_environment(mode=bm.training_mode, dt=1.)


class BasicBlock(bp.DynamicalSystem):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1, is_last=False):
    super(BasicBlock, self).__init__()
    self.is_last = is_last
    self.conv1 = bp.layers.Conv2D(in_planes, planes, kernel_size=(3, 3), strides=stride, padding=(1, 1),
                                  w_initializer=bp.init.KaimingNormal(mode='fan_out'))
    self.bn1 = bp.layers.BatchNorm2D(planes)
    self.conv2 = bp.layers.Conv2D(planes, planes, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
                                  w_initializer=bp.init.KaimingNormal(mode='fan_out'))
    self.bn2 = bp.layers.BatchNorm2D(planes)

    # The shortcut is the identity unless the spatial size or channel count
    # changes, in which case a 1x1 convolution + batch norm projects the input.
    # self.shortcut = bp.layers.Identity()
    self.shortcut = bp.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = bp.Sequential(
        bp.layers.Conv2D(in_planes, self.expansion * planes, kernel_size=1, strides=stride,
                         w_initializer=bp.init.KaimingNormal(mode='fan_out')),
        bp.layers.BatchNorm2D(self.expansion * planes)
      )

  def update(self, s, x):
    out = bm.relu(self.bn1(s, self.conv1(s, x)))
    out = self.bn2(s, self.conv2(s, out))
    out += self.shortcut(s, x)
    preact = out
    out = bm.relu(out)
    if self.is_last:
      return out, preact
    else:
      return out


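A quick shape check for BasicBlock, as a sketch rather than part of the commit; the NHWC layout and block sizes mirror how the class is used later in this file:

with bm.training_environment():
  block = BasicBlock(in_planes=64, planes=128, stride=2)
  x = bm.random.randn(4, 32, 32, 64)   # (batch, height, width, channels)
  y = block({'fit': False}, x)         # shared args: inference mode
  print(y.shape)                       # expected: (4, 16, 16, 128)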
class Bottleneck(bp.DynamicalSystem):
  expansion = 4

  def __init__(self, in_planes, planes, stride=1, is_last=False):
    super(Bottleneck, self).__init__()
    self.is_last = is_last
    self.conv1 = bp.layers.Conv2D(in_planes, planes, kernel_size=(1, 1),
                                  w_initializer=bp.init.KaimingNormal(mode='fan_out'))
    self.bn1 = bp.layers.BatchNorm2D(planes)
    self.conv2 = bp.layers.Conv2D(planes, planes, kernel_size=(3, 3), strides=stride, padding=(1, 1),
                                  w_initializer=bp.init.KaimingNormal(mode='fan_out'))
    self.bn2 = bp.layers.BatchNorm2D(planes)
    self.conv3 = bp.layers.Conv2D(planes, self.expansion * planes, kernel_size=(1, 1),
                                  w_initializer=bp.init.KaimingNormal(mode='fan_out'))
    self.bn3 = bp.layers.BatchNorm2D(self.expansion * planes)

    # self.shortcut = bp.layers.Identity()
    self.shortcut = bp.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = bp.Sequential(
        bp.layers.Conv2D(in_planes, self.expansion * planes, kernel_size=1, strides=stride,
                         w_initializer=bp.init.KaimingNormal(mode='fan_out')),
        bp.layers.BatchNorm2D(self.expansion * planes)
      )

  def update(self, s, x):
    out = bm.relu(self.bn1(s, self.conv1(s, x)))
    out = bm.relu(self.bn2(s, self.conv2(s, out)))
    out = self.bn3(s, self.conv3(s, out))
    out += self.shortcut(s, x)
    preact = out
    out = bm.relu(out)
    if self.is_last:
      return out, preact
    else:
      return out


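Unlike BasicBlock, the bottleneck squeezes channels with a 1x1 convolution, applies the 3x3 convolution at the reduced width, then expands by expansion = 4: with planes=64, the channel path is in_planes -> 64 -> 64 -> 256, which is why the shortcut projection also targets self.expansion * planes.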
class ResNet(bp.DynamicalSystem):
  def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False):
    super(ResNet, self).__init__()
    self.in_planes = 64

    # The stem: MNIST images are reshaped to single-channel (N, 28, 28, 1)
    # in main(), so the first convolution takes one input channel.
    self.conv1 = bp.layers.Conv2D(1, 64, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
                                  w_initializer=bp.init.KaimingNormal(mode='fan_out'))
    self.bn1 = bp.layers.BatchNorm2D(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    # Global average pooling is done with bm.mean in update() instead of an
    # adaptive average-pooling layer.
    self.linear = bp.layers.Dense(512 * block.expansion, num_classes)

    # Zero-initialize the last BN in each residual branch, so that the branch
    # starts as zeros and each residual block behaves like an identity.
    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zero_init_residual:
      for m in self.nodes():
        if isinstance(m, Bottleneck):
          m.bn3.scale[:] = 0
        elif isinstance(m, BasicBlock):
          m.bn2.scale[:] = 0

  def get_bn_before_relu(self):
    if isinstance(self.layer1[0], Bottleneck):
      bn1 = self.layer1[-1].bn3
      bn2 = self.layer2[-1].bn3
      bn3 = self.layer3[-1].bn3
      bn4 = self.layer4[-1].bn3
    elif isinstance(self.layer1[0], BasicBlock):
      bn1 = self.layer1[-1].bn2
      bn2 = self.layer2[-1].bn2
      bn3 = self.layer3[-1].bn2
      bn4 = self.layer4[-1].bn2
    else:
      raise NotImplementedError('ResNet: unknown block type.')

    return [bn1, bn2, bn3, bn4]

  def _make_layer(self, block, planes, num_blocks, stride):
    # The first block of a stage may downsample; the rest keep stride 1.
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for i in range(num_blocks):
      stride = strides[i]
      layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1))
      self.in_planes = planes * block.expansion
    return bp.Sequential(*layers)

  def update(self, s, x, is_feat=False, preact=False):
    out = bm.relu(self.bn1(s, self.conv1(s, x)))
    f0 = out
    out, f1_pre = self.layer1(s, out)
    f1 = out
    out, f2_pre = self.layer2(s, out)
    f2 = out
    out, f3_pre = self.layer3(s, out)
    f3 = out
    out, f4_pre = self.layer4(s, out)
    f4 = out
    # Global average pooling over the spatial dimensions (NHWC layout).
    out = bm.mean(out, axis=(1, 2))
    f5 = out
    out = self.linear(s, out)
    if is_feat:
      if preact:
        return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out
      else:
        return [f0, f1, f2, f3, f4, f5], out
    else:
      return out


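For the single-channel 28x28 MNIST inputs used in main() below, standard convolution arithmetic gives the feature shapes (with BasicBlock, expansion = 1): f0 and f1 are (N, 28, 28, 64), f2 is (N, 14, 14, 128), f3 is (N, 7, 7, 256), f4 is (N, 4, 4, 512), and f5 is (N, 512) after the spatial mean.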
def ResNet18(**kwargs):
  return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def ResNet34(**kwargs):
  return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def ResNet50(**kwargs):
  return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def ResNet101(**kwargs):
  return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def ResNet152(**kwargs):
  return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


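The names count weighted layers: ResNet18 stacks [2, 2, 2, 2] BasicBlocks, i.e. 8 blocks x 2 convolutions, plus the stem convolution and the final dense layer, so 16 + 2 = 18; ResNet50 stacks [3, 4, 6, 3] Bottlenecks, i.e. 16 blocks x 3 convolutions + 2 = 50.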
def main():
  parser = argparse.ArgumentParser(description='Classify MNIST')
  parser.add_argument('-platform', default='cpu', help='platform')
  parser.add_argument('-batch', default=128, type=int, help='batch size')
  parser.add_argument('-n_epoch', default=64, type=int, metavar='N', help='number of total epochs to run')
  parser.add_argument('-data-dir', default='./data', type=str, help='root dir of the MNIST dataset')
  parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoints')
  parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
  args = parser.parse_args()
  print(args)

  bm.set_platform(args.platform)

  out_dir = os.path.join(args.out_dir, f'b{args.batch}_lr{args.lr}_epoch{args.n_epoch}')

  # dataset: images are scaled to [0, 1] and reshaped to (N, 28, 28, 1)
  train_set = bd.vision.MNIST(root=args.data_dir, split='train', download=True)
  test_set = bd.vision.MNIST(root=args.data_dir, split='test', download=True)
  x_train = bm.asarray(train_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1))
  y_train = bm.asarray(train_set.targets, dtype=bm.int_)
  x_test = bm.asarray(test_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1))
  y_test = bm.asarray(test_set.targets, dtype=bm.int_)

  with bm.training_environment():
    net = ResNet18(num_classes=10)

  # loss function: returns the cross-entropy loss and the number of correct predictions
  @bm.to_object(child_objs=net)
  def loss_fun(X, Y, fit=True):
    s = {'fit': fit}
    predictions = net(s, X)
    l = bp.losses.cross_entropy_loss(predictions, Y)
    n = bm.sum(predictions.argmax(1) == Y)
    return l, n

  grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True)

  # optimizer with an exponentially decaying learning rate
  optimizer = bp.optim.Adam(bp.optim.ExponentialDecay(args.lr, 1, 0.9999),
                            train_vars=net.train_vars().unique())

  @bm.jit
  @bm.to_object(child_objs=(grad_fun, optimizer))
  def train_fun(X, Y):
    grads, l, n = grad_fun(X, Y)
    optimizer.update(grads)
    return l, n

  predict_loss_fun = bm.jit(partial(loss_fun, fit=False), child_objs=loss_fun)

  os.makedirs(out_dir, exist_ok=True)
  with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write(str(args))
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

  max_test_acc = -1
  for epoch_i in range(args.n_epoch):
    # training
    start_time = time.time()
    loss, train_acc = [], 0.
    for i in range(0, x_train.shape[0], args.batch):
      xs = x_train[i: i + args.batch]
      ys = y_train[i: i + args.batch]
      l, n = train_fun(xs, ys)
      if (i / args.batch) % 100 == 0:
        print(f'Epoch {epoch_i}: Train {i} batch, loss = {bm.mean(l):.4f}')
      loss.append(l)
      train_acc += n
    train_acc /= x_train.shape[0]
    train_loss = bm.mean(bm.asarray(loss))

    # validation
    loss, test_acc = [], 0.
    for i in range(0, x_test.shape[0], args.batch):
      xs = x_test[i: i + args.batch]
      ys = y_test[i: i + args.batch]
      l, n = predict_loss_fun(xs, ys)
      loss.append(l)
      test_acc += n
    test_acc /= x_test.shape[0]
    test_loss = bm.mean(bm.asarray(loss))

    t = time.time() - start_time
    print(f'epoch {epoch_i}, used {t:.3f} seconds, '
          f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, '
          f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}')

    # checkpoint the best model so far
    if max_test_acc < test_acc:
      max_test_acc = test_acc
      states = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch_i': epoch_i,
        'train_acc': train_acc,
        'test_acc': test_acc,
      }
      bp.checkpoints.save(out_dir, states, epoch_i)

  # inference with the best checkpoint
  state_dict = bp.checkpoints.load(out_dir)
  net.load_state_dict(state_dict['net'])
  correct_num = 0
  for i in range(0, x_test.shape[0], 512):
    xs = x_test[i: i + 512]
    ys = y_test[i: i + 512]
    correct_num += predict_loss_fun(xs, ys)[1]
  print('Max test accuracy: ', correct_num / x_test.shape[0])


if __name__ == '__main__':
  main()

  # Quick forward-pass benchmark:
  #
  # import time
  #
  # with bm.training_environment():
  #   net = ResNet34()
  #   x = bm.random.randn(2, 32, 32, 1)
  #   start = time.time()
  #   feats, logit = net({'fit': False}, x, is_feat=True, preact=True)
  #   end = time.time()
  #   print(f'time: {end - start}')
  #
  #   start = time.time()
  #   feats, logit = net({'fit': False}, x, is_feat=True, preact=True)
  #   end = time.time()
  #   print(f'time: {end - start}')
  #
  #   for f in feats:
  #     print(f.shape, f.min().item(), f.max().item())
  #   print(logit.shape)
  #
  #   for m in net.get_bn_before_relu():
  #     if isinstance(m, bp.layers.BatchNorm2D):
  #       print('pass')
  #     else:
  #       print('warning')
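To run the example without typing flags on a shell, one can inject argv before calling main(); a small sketch (the script name 'resnet_example.py' is a placeholder, and the flags mirror the argparse definitions above):

import sys
sys.argv = ['resnet_example.py',  # placeholder script name
            '-platform', 'cpu', '-batch', '64', '-n_epoch', '1']
main()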
