# -*- coding: utf-8 -*-
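"""ResNet (BasicBlock / Bottleneck variants) built on BrainPy's
``DynamicalSystem`` API, together with a training script on MNIST.

The architecture is the CIFAR-style ResNet (3x3 stem, no max-pooling);
blocks can additionally expose pre-activation feature maps via the
``is_feat``/``preact`` flags of ``ResNet.update``.
"""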

import argparse
import os
import sys
import time
from functools import partial

import brainpy as bp
import brainpy.math as bm
import brainpy_datasets as bd

# Use BrainPy's training mode globally (e.g. BatchNorm uses batch statistics
# during fitting); dt is the integration step, unused by these static layers.
bm.set_environment(mode=bm.training_mode, dt=1.)

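# Every network below follows BrainPy's update convention: the first argument
# ``s`` of ``update`` is a dict of shared settings (here ``{'fit': bool}``)
# that is threaded through every child layer call.
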
class BasicBlock(bp.DynamicalSystem):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = bp.layers.Conv2D(in_planes, planes, kernel_size=(3, 3), strides=stride, padding=(1, 1),
                                      w_initializer=bp.init.KaimingNormal(mode='fan_out'))
        self.bn1 = bp.layers.BatchNorm2D(planes)
        self.conv2 = bp.layers.Conv2D(planes, planes, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
                                      w_initializer=bp.init.KaimingNormal(mode='fan_out'))
        self.bn2 = bp.layers.BatchNorm2D(planes)

        # Identity shortcut by default; a projection (1x1 conv + BN) when the
        # spatial size or channel count changes.
        self.shortcut = bp.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = bp.Sequential(
                bp.layers.Conv2D(in_planes, self.expansion * planes, kernel_size=(1, 1), strides=stride,
                                 w_initializer=bp.init.KaimingNormal(mode='fan_out')),
                bp.layers.BatchNorm2D(self.expansion * planes)
            )

    def update(self, s, x):
        out = bm.relu(self.bn1(s, self.conv1(s, x)))
        out = self.bn2(s, self.conv2(s, out))
        out += self.shortcut(s, x)
        preact = out
        out = bm.relu(out)
        # The last block of a stage also returns its pre-activation output.
        if self.is_last:
            return out, preact
        else:
            return out


class Bottleneck(bp.DynamicalSystem):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = bp.layers.Conv2D(in_planes, planes, kernel_size=(1, 1),
                                      w_initializer=bp.init.KaimingNormal(mode='fan_out'))
        self.bn1 = bp.layers.BatchNorm2D(planes)
        self.conv2 = bp.layers.Conv2D(planes, planes, kernel_size=(3, 3), strides=stride, padding=(1, 1),
                                      w_initializer=bp.init.KaimingNormal(mode='fan_out'))
        self.bn2 = bp.layers.BatchNorm2D(planes)
        self.conv3 = bp.layers.Conv2D(planes, self.expansion * planes, kernel_size=(1, 1),
                                      w_initializer=bp.init.KaimingNormal(mode='fan_out'))
        self.bn3 = bp.layers.BatchNorm2D(self.expansion * planes)

        # Identity shortcut by default; a projection when shapes change.
        self.shortcut = bp.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = bp.Sequential(
                bp.layers.Conv2D(in_planes, self.expansion * planes, kernel_size=(1, 1), strides=stride,
                                 w_initializer=bp.init.KaimingNormal(mode='fan_out')),
                bp.layers.BatchNorm2D(self.expansion * planes)
            )

    def update(self, s, x):
        out = bm.relu(self.bn1(s, self.conv1(s, x)))
        out = bm.relu(self.bn2(s, self.conv2(s, out)))
        out = self.bn3(s, self.conv3(s, out))
        out += self.shortcut(s, x)
        preact = out
        out = bm.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class ResNet(bp.DynamicalSystem):
    def __init__(self, block, num_blocks, num_classes=10, in_channels=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # CIFAR-style stem: a single 3x3 convolution, no max-pooling.
        # ``in_channels`` must match the input images (e.g. 1 for MNIST).
        self.conv1 = bp.layers.Conv2D(in_channels, 64, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
                                      w_initializer=bp.init.KaimingNormal(mode='fan_out'))
        self.bn1 = bp.layers.BatchNorm2D(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = bp.layers.Dense(512 * block.expansion, num_classes)

        # Zero-initialize the last BN in each residual branch, so that the
        # branch starts at zero and each residual block initially behaves like
        # an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677.
        if zero_init_residual:
            for m in self.nodes():
                if isinstance(m, Bottleneck):
                    m.bn3.scale[:] = 0
                elif isinstance(m, BasicBlock):
                    m.bn2.scale[:] = 0

    def get_bn_before_relu(self):
        """Return the last BatchNorm layer (the one feeding the final ReLU) of each stage."""
        if isinstance(self.layer1[0], Bottleneck):
            bn1 = self.layer1[-1].bn3
            bn2 = self.layer2[-1].bn3
            bn3 = self.layer3[-1].bn3
            bn4 = self.layer4[-1].bn3
        elif isinstance(self.layer1[0], BasicBlock):
            bn1 = self.layer1[-1].bn2
            bn2 = self.layer2[-1].bn2
            bn3 = self.layer3[-1].bn2
            bn4 = self.layer4[-1].bn2
        else:
            raise NotImplementedError('Unknown ResNet block type.')

        return [bn1, bn2, bn3, bn4]

    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block of a stage may downsample with ``stride``; the rest
        # use stride 1. Only the final block is flagged ``is_last`` so that it
        # also returns its pre-activation output.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            layers.append(block(self.in_planes, planes, strides[i], i == num_blocks - 1))
            self.in_planes = planes * block.expansion
        return bp.Sequential(*layers)

    def update(self, s, x, is_feat=False, preact=False):
        out = bm.relu(self.bn1(s, self.conv1(s, x)))
        f0 = out
        out, f1_pre = self.layer1(s, out)
        f1 = out
        out, f2_pre = self.layer2(s, out)
        f2 = out
        out, f3_pre = self.layer3(s, out)
        f3 = out
        out, f4_pre = self.layer4(s, out)
        f4 = out
        # Global average pooling over the spatial axes (inputs are NHWC).
        out = bm.mean(out, axis=(1, 2))
        f5 = out
        out = self.linear(s, out)
        if is_feat:
            if preact:
                return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out
            else:
                return [f0, f1, f2, f3, f4, f5], out
        else:
            return out


def ResNet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def ResNet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def ResNet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def ResNet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def ResNet152(**kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


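# A minimal usage sketch (mirrors the commented smoke test at the bottom of
# this file): build a network inside a training environment and run a forward
# pass. The input shape here is illustrative.
#
#   with bm.training_environment():
#       net = ResNet18(num_classes=10, in_channels=3)
#       x = bm.random.randn(2, 32, 32, 3)  # NHWC
#       feats, logits = net({'fit': False}, x, is_feat=True, preact=True)

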
def main():
    parser = argparse.ArgumentParser(description='Classify MNIST with a BrainPy ResNet')
    parser.add_argument('-platform', default='cpu', help='computing platform: cpu, gpu, or tpu')
    parser.add_argument('-batch', default=128, type=int, help='batch size')
    parser.add_argument('-n_epoch', default=64, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-data-dir', default='./data', type=str, help='root dir of the MNIST dataset')
    parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoints')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    args = parser.parse_args()
    print(args)

    bm.set_platform(args.platform)

    out_dir = os.path.join(args.out_dir, f'b{args.batch}_lr{args.lr}_epoch{args.n_epoch}')

    # Dataset: normalize pixels to [0, 1] and reshape to NHWC with one channel.
    train_set = bd.vision.MNIST(root=args.data_dir, split='train', download=True)
    test_set = bd.vision.MNIST(root=args.data_dir, split='test', download=True)
    x_train = bm.asarray(train_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1))
    y_train = bm.asarray(train_set.targets, dtype=bm.int_)
    x_test = bm.asarray(test_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1))
    y_test = bm.asarray(test_set.targets, dtype=bm.int_)

    with bm.training_environment():
        net = ResNet18(num_classes=10, in_channels=1)  # MNIST images are single-channel

    # Loss function: ``s`` carries shared settings; ``fit`` switches BatchNorm
    # between batch statistics (training) and running statistics (evaluation).
    @bm.to_object(child_objs=net)
    def loss_fun(X, Y, fit=True):
        s = {'fit': fit}
        predictions = net(s, X)
        l = bp.losses.cross_entropy_loss(predictions, Y)
        n = bm.sum(predictions.argmax(1) == Y)  # number of correct predictions
        return l, n

    # With ``has_aux=True`` and ``return_value=True``, the gradient function
    # returns the gradients, the loss value, and the auxiliary output.
    grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True)

    # Optimizer: Adam with an exponentially decaying learning rate.
    optimizer = bp.optim.Adam(bp.optim.ExponentialDecay(args.lr, 1, 0.9999),
                              train_vars=net.train_vars().unique())

    @bm.jit
    @bm.to_object(child_objs=(grad_fun, optimizer))
    def train_fun(X, Y):
        grads, l, n = grad_fun(X, Y)
        optimizer.update(grads)
        return l, n

    # Evaluation uses ``fit=False`` so BatchNorm relies on running statistics.
    predict_loss_fun = bm.jit(partial(loss_fun, fit=False), child_objs=loss_fun)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    max_test_acc = -1.
    for epoch_i in range(args.n_epoch):
        start_time = time.time()

        # Training
        loss, train_acc = [], 0.
        for i in range(0, x_train.shape[0], args.batch):
            xs = x_train[i: i + args.batch]
            ys = y_train[i: i + args.batch]
            l, n = train_fun(xs, ys)
            if (i // args.batch) % 100 == 0:
                print(f'Epoch {epoch_i}: Train {i} batch, loss = {bm.mean(l):.4f}')
            loss.append(l)
            train_acc += n
        train_acc /= x_train.shape[0]
        train_loss = bm.mean(bm.asarray(loss))

        # Evaluation
        loss, test_acc = [], 0.
        for i in range(0, x_test.shape[0], args.batch):
            xs = x_test[i: i + args.batch]
            ys = y_test[i: i + args.batch]
            l, n = predict_loss_fun(xs, ys)
            loss.append(l)
            test_acc += n
        test_acc /= x_test.shape[0]
        test_loss = bm.mean(bm.asarray(loss))

        t = time.time() - start_time
        print(f'epoch {epoch_i}, used {t:.3f} seconds, '
              f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, '
              f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}')

        # Checkpoint whenever test accuracy improves.
        if max_test_acc < test_acc:
            max_test_acc = test_acc
            states = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch_i': epoch_i,
                'train_acc': train_acc,
                'test_acc': test_acc,
            }
            bp.checkpoints.save(out_dir, states, epoch_i)

    # Inference with the best checkpoint (saves only happen on improvement,
    # so the latest checkpoint is the best one).
    state_dict = bp.checkpoints.load(out_dir)
    net.load_state_dict(state_dict['net'])
    correct_num = 0
    for i in range(0, x_test.shape[0], 512):
        xs = x_test[i: i + 512]
        ys = y_test[i: i + 512]
        correct_num += predict_loss_fun(xs, ys)[1]
    print('Max test accuracy: ', correct_num / x_test.shape[0])


if __name__ == '__main__':
    main()
    # Smoke test, kept for reference:
    #
    # with bm.training_environment():
    #     net = ResNet34(in_channels=1)
    # x = bm.random.randn(2, 32, 32, 1)
    # start = time.time()
    # feats, logit = net({'fit': False}, x, is_feat=True, preact=True)
    # end = time.time()
    # print(f'time: {end - start}')
    #
    # start = time.time()
    # feats, logit = net({'fit': False}, x, is_feat=True, preact=True)
    # end = time.time()
    # print(f'time: {end - start}')
    #
    # for f in feats:
    #     print(f.shape, f.min().item(), f.max().item())
    # print(logit.shape)
    #
    # for m in net.get_bn_before_relu():
    #     if isinstance(m, bp.layers.BatchNorm2D):
    #         print('pass')
    #     else:
    #         print('warning')