slow_main.py
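"""Entry point for the SlowChangeReg online-learning experiment.

Parses command-line arguments, builds keyword arguments for the selected
model (ffnn, context_ffnn, transformer, or rnn), launches online_learn for
each random seed in parallel via joblib, and plots the resulting flip loss.
"""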
import os
from argparse import ArgumentParser, Namespace

import torch
from joblib import Parallel, delayed

from supervised.plotter import plot_flip_loss
from supervised.slow_learn import online_learn
from utils import make_log_dir, save_config

if __name__ == '__main__':
    parser = ArgumentParser()
    # model selection
    parser.add_argument('--model', type=str,
                        help='model to train', default='ffnn',
                        choices=['ffnn', 'transformer', 'rnn', 'context_ffnn'])
    # FFNN model arguments
    parser.add_argument('--ffnn_dim_hidden', type=int,
                        help='FFNN hidden dimension', default=20)
    parser.add_argument('-lffnn', '--ffnn_layers', type=int,
                        help='number of FFNN layers', default=2)
    # transformer model arguments
    parser.add_argument('-ltf', '--tf_layers', type=int,
                        help='number of Transformer layers', default=2)
    parser.add_argument('--tf_mode', type=str,
                        help='training mode: auto-regressive or sequential',
                        default='auto', choices=['auto', 'sequential'])
    parser.add_argument('-a', '--tf_activation', type=str,
                        help='transformer activation function', default='linear',
                        choices=['linear', 'relu', 'softmax'])
    parser.add_argument('--apply_mask', action='store_true',
                        help='apply mask to the transformer')
    # rnn model arguments
    parser.add_argument('-lrnn', '--rnn_layers', type=int,
                        help='number of RNN layers', default=1)
    # optimizer arguments
    parser.add_argument('-lr', '--learning_rate', type=float,
                        help='learning rate', default=0.01)
    parser.add_argument('--weight_decay', type=float,
                        help='regularization term', default=0.0)
    # experiment arguments
    parser.add_argument('-d', '--dim_feature', type=int,
                        help='feature dimension', default=20)
    parser.add_argument('-n', '--ctxt_len', type=int,
                        help='context length', default=100)
    parser.add_argument('-b', '--batch_size', type=int,
                        help='mini batch size', default=1)
    parser.add_argument('--benchmark', action='store_true',
                        help='use per-task initialized benchmark for the models')
    parser.add_argument('--seed', type=int, nargs='+',
                        help='random seed', default=list(range(10)))
    parser.add_argument('--bin_size', type=int,
                        help='bin size for averaging', default=40_000)
    parser.add_argument('--prefix', type=str,
                        help='prefix to add to the save directory', default='')
    parser.add_argument('--suffix', type=str,
                        help='suffix to add to the save directory', default='')
    parser.add_argument('--ckpt', type=int,
                        help='checkpoint to load', default=None)
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print training details')
    args: Namespace = parser.parse_args()
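
    # create the logging directory and persist the run configuration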
    dir_dict = make_log_dir(args.seed, 'SlowChangeReg',
                            prefix=args.prefix, suffix=args.suffix)
    save_config(dir_dict['root'], vars(args))
    if args.verbose:
        for k, v in vars(args).items():
            print(f'{k}: {v}')
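
    # build keyword arguments for the selected model architecture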
    works = []
    if args.model == 'ffnn':
        model_args = dict(dim_input=args.dim_feature,
                          dim_output=1,
                          dim_hidden=args.ffnn_dim_hidden,
                          num_hidden_layers=args.ffnn_layers)
    elif args.model == 'context_ffnn':
        model_args = dict(dim_input=args.dim_feature,
                          dim_output=1,
                          dim_hidden=args.ffnn_dim_hidden,
                          num_hidden_layers=args.ffnn_layers,
                          n_ctxt=args.ctxt_len)
    elif args.model == 'transformer':
        model_args = dict(d_feat=args.dim_feature,
                          d_out=1,
                          n_ctxt=args.ctxt_len,
                          n_layers=args.tf_layers,
                          mode=args.tf_mode,
                          activation=args.tf_activation,
                          apply_mask=args.apply_mask)
    elif args.model == 'rnn':
        model_args = dict(dim_input=args.dim_feature,
                          dim_output=1,
                          num_layers=args.rnn_layers)
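
    # queue one online_learn run per seed, optionally resuming from a saved checkpoint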
    for seed in args.seed:
        ckpt = None if args.ckpt is None else torch.load(
            os.path.join(dir_dict['ckpt'][seed], f'ckpt_{args.ckpt}.pt'),
            weights_only=False)
        run_config = dict(log_path=dir_dict['root'],
                          seed=seed,
                          model_key=args.model,
                          model_args=model_args,
                          dim_feature=args.dim_feature,
                          ctxt_len=args.ctxt_len,
                          lr=args.learning_rate,
                          batch_size=args.batch_size,
                          weight_decay=args.weight_decay,
                          benchmark=args.benchmark,
                          ckpt=ckpt)
        works.append(delayed(online_learn)(**run_config))
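
    # execute all queued runs in parallel, then plot the flip loss for the chosen seeds and model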
    Parallel(n_jobs=-1)(works)
    plot_flip_loss(dir_dict, args.seed, args.model,
                   args.benchmark, args.bin_size)