-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
156 lines (122 loc) · 7.17 KB
/
main.py
File metadata and controls
156 lines (122 loc) · 7.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import argparse
# Load the settings
parser = argparse.ArgumentParser(description='원하는 하이퍼파라미터를 설정할 수 있습니다. eg. python main.py --resolution 256 --num_epochs 20 | 도움말. python main.py -h')
parser.add_argument('--resolution', default=8, type=int, help="mnmx에서 설정한 범위(learning rate and weight offest)를 얼마나 세세하게 관찰할지 결정할 수 있습니다.\n이 값의 제곱만큼 모델 수가 결정됩니다. 너무 작으면 fractal dimension이 구해지지 않을 수 있어요!")
parser.add_argument('--num_epochs', default=100, type=int, help="학습 시 epochs를 설정합니다. 20epochs면 거의 수렴해요.")
parser.add_argument('--batch_size', default=32, type=int, help="batch size를 결정합니다.")
parser.add_argument('--target_dim', default=10, type=int, help="이미지의 라벨종류가 몇개인가요?")
parser.add_argument('--nonlinearity', default='relu', type=str, help="convoluion 직후에 적용할 activation입니다. 지금은 'relu'와 'leaky'만 지원해요.")
parser.add_argument('--optimizer', default='sgd', type=str, help="optmizer입니다. 지금은 'sgd'와 'adam'만 지원해요. ")
parser.add_argument('--c_hidden', default=[16, 32, 64], type=int, nargs='+', help="block을 통과하고 나서 나오는 텐서의 채널 수를 설정합니다.\neg. --c_hidden 16 32 64")
parser.add_argument('--tile_batch', default=1024, type=int, help="1번 training할 동안 몇 개의 모델을 동시에 굴릴지 설정합니다.\n모델 수(resolution**2)가 이거보다 많으면 절반으로 쪼개어 이거보다 작을 때까지 반복합니다.\n이 값이 클수록 한번에 많은 모델을 학습시킬 수 있지만, 너무 크면 OOM이 발생할 수 있습니다.\n2의 제곱수로 조절하세요.")
parser.add_argument('--mnmx', default=[-4, 0, -4, 0], type=int, nargs='+', help="learning rate의 범위 하한값, 상한값, weight offset의 범위 하한값, 상한값을 설정합니다. eg. --mnmx -4 0 -4 0")
parser.add_argument('--dpi', default=100, type=int, help="PNG 파일의 해상도 값입니다. 냅두셔도 됩니다.")
parser.add_argument('--figsize', default=[8, 8], type=int, nargs='+', help="lossmap의 figure size를 결정합니다. 정사각형꼴로 설정하세요.")
args = parser.parse_args()
from load_libraries import *
from protocol_save import *
from protocol_train import *
from protocol_plot import *
from datasets.mnist import *
from model.resnet_v4 import *
from typing import Any
from functools import partial
from pprint import pformat
import pandas as pd
# Metadata: hyperparmas
pd.DataFrame(args._get_kwargs()).to_csv(output_path + '/hyperparams.csv', header=False, index=False)
# Allocate the args
for arg in args._get_kwargs():
k, v = arg
if not isinstance(v, str):
exec('%s = %s' % (k, v))
else:
exec('%s = "%s"' % (k, v))
# Define activations and optimizers
if nonlinearity == 'relu':
exec('nonlinearity = nn.relu')
elif nonlinearity == 'leaky':
exec('nonlinearity = nn.leaky_relu')
if optimizer == 'sgd':
optimizer = optax.sgd
elif optimizer == 'adam':
optimizer = optax.adam
# Tiling and plotting functions
@partial(jax.vmap, in_axes=(None, 0, None))
def train_step_v(variables, lr, model):
state = TrainState.create(
apply_fn=model.apply,
params=jax.tree_util.tree_map(lambda param: param + lr[0], variables['params']),
batch_stats=variables['batch_stats'],
tx=optimizer(lr[1])
)
M_train = []
for _ in tqdm(range(num_epochs), total=num_epochs, leave=False, desc='Epochs'):
for batch in tqdm(train_ds.as_numpy_iterator(), total=total_batch, leave=False, desc='Iter'):
state, metrics = train_step(state, batch)
M_train.append(metrics)
return M_train
def make_array(metrics_v, target):
return np.vstack([metrics_v[i][target] for i in range(num_epochs)]).T # (px, epochs)
def train_step_tile(variables, lrs, model, tile_batch=tile_batch):
bs = lrs.shape[0]
particles = bs//tile_batch
if particles > 1:
print(f"Splitting tiles as {bs}->{bs//2}tiles.")
if bs > tile_batch:
metrics1 = train_step_tile(variables, lrs[:bs//2], model)
metrics2 = train_step_tile(variables, lrs[bs//2:], model)
acc_v1 = metrics1['accuracy']
loss_v1 = metrics1['loss']
acc_v2 = metrics2['accuracy']
loss_v2 = metrics2['loss']
acc = np.vstack([acc_v1, acc_v2])
loss = np.vstack([loss_v1, loss_v2])
return {'accuracy': acc, 'loss': loss}
metrics = train_step_v(variables, lrs, model)
acc, loss = make_array(metrics, 'accuracy'), make_array(metrics, 'loss')
return {'accuracy': acc, 'loss': loss}
def sketch_convmap(conv, title, saveas=None):
plot_img(conv.reshape((resolution, resolution)), mnmx, title=title, savename=saveas)
# Scaling sketch
lrs = scaling_sketch(mnmx, resolution)
# Prepare dataset
train_ds, test_ds = prepare_dataset(batch_size)
total_batch = train_ds.cardinality().numpy()
total_tbatch = test_ds.cardinality().numpy()
batch = next(iter(train_ds))
x, y = batch['image'], batch['label']
# Model loading
resnet20 = ResNet(10, nonlinearity, ResNetBlock)
variables = resnet20.init(jax.random.PRNGKey(1), x)
# Session
msg_start = 'Training start!\n\n' + \
'DEVICE: ' + f'{len(jax.devices())} ' + jax.lib.xla_bridge.get_backend().platform + '\n' + \
'DIR: ' + output_path + '\n' + \
'Hyperparams: \n' + '='*50 + '\n' + \
pformat(args) + \
'\n' + '='*50
send_alaram(msg_start)
metrics_tile = train_step_tile(variables, lrs, resnet20, tile_batch=tile_batch) # (px, epoch)
# Draw fractal image
send_alaram('Drawing...')
for i in tqdm(range(1, num_epochs), total=num_epochs-1, desc='Drawing', leave=False):
# acc_conv = convergence_measure(metrics_tile['accuracy'][:, :i+1])
acc_conv = accuracy_measure(metrics_tile['accuracy'][:, :i+1])
loss_conv = convergence_measure(metrics_tile['loss'][:, :i+1])
# Measure FD and plot the bordelines and convolution maps
acc_edges = get_edges(acc_conv, resolution)
loss_edges = get_edges(loss_conv, resolution)
acc_fd = estimate_fractal_dimension(acc_edges, output_path + f'/train/accuracy/fd/epoch{i:03d}.png')
loss_fd = estimate_fractal_dimension(loss_edges, output_path + f'/train/loss/fd/epoch{i:03d}.png')
# dessin_borderline(acc_edges, mnmx, title=f'Training-accuracy\n({i}epoch(s), FD={acc_fd})', saveas=output_path + f'/train/accuracy/border/epoch{i:03d}.png')
dessin_borderline(loss_edges, mnmx, title=f'Training-loss\n({i}epoch(s), FD={loss_fd})', saveas=output_path + f'/train/loss/border/epoch{i:03d}.png')
# draw convergence map
sketch_convmap(acc_conv, title=f'Training-accuracy\n({i}epoch(s), FD={acc_fd})', saveas=output_path + f'/train/accuracy/epoch{i:03d}.png')
sketch_convmap(loss_conv, title=f'Training-loss\n({i}epoch(s), FD={loss_fd})', saveas=output_path + f'/train/loss/epoch{i:03d}.png')
# Animation as epochs
animate_sketches(output_path + '/train/loss')
animate_sketches(output_path + '/train/loss/border')
animate_sketches(output_path + '/train/accuracy')
animate_sketches(output_path + '/train/accuracy/border')
send_alaram('Successfully over!\nDIR: ' + output_path)