-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtd_main.py
More file actions
110 lines (101 loc) · 4.99 KB
/
td_main.py
File metadata and controls
110 lines (101 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from argparse import ArgumentParser, Namespace
import torch
from joblib import Parallel, delayed
from rl.td_learn import online_learn
from utils import make_log_dir, save_config
def build_parser() -> ArgumentParser:
    """Create the command-line parser for the TD experiment script.

    Returns:
        An ``ArgumentParser`` with model-selection, per-model architecture,
        optimizer, and experiment options registered.
    """
    parser = ArgumentParser()
    # model selection
    parser.add_argument('--model', type=str,
                        help='model to train', default='ffnn',
                        choices=['ffnn', 'transformer', 'rnn', 'context_ffnn'])
    # FFNN model arguments
    parser.add_argument('--ffnn_dim_hidden', type=int,
                        help='FFNN hidden dimension', default=30)
    parser.add_argument('-lffnn', '--ffnn_layers', type=int,
                        help='number of FFNN layers', default=2)
    # transformer model arguments
    parser.add_argument('-ltf', '--tf_layers', type=int,
                        help='number of Transformer layers', default=6)
    parser.add_argument('--tf_mode', type=str,
                        help='training mode: auto-regressive or sequential',
                        default='auto', choices=['auto', 'sequential'])
    parser.add_argument('-a', '--tf_activation', type=str,
                        help='transformer activation function', default='softmax',
                        choices=['linear', 'relu', 'softmax'])
    parser.add_argument('--apply_mask', action='store_true',
                        help='apply mask to the transformer')
    # rnn model arguments
    parser.add_argument('-lrnn', '--rnn_layers', type=int,
                        help='number of RNN layers', default=6)
    # optimizer arguments
    parser.add_argument('-lr', type=float,
                        help='learning rate', default=0.003)
    parser.add_argument('--weight_decay', type=float,
                        help='regularization term', default=0.0)
    # experiment arguments
    parser.add_argument('-d', '--dim_feature', type=int,
                        help='feature dimension', default=4)
    parser.add_argument('-n', '--ctxt_len', type=int,
                        help='context length', default=100)
    parser.add_argument('--benchmark', action='store_true',
                        help='use per-task initialized benchmark for the models')
    parser.add_argument('--seed', type=int, nargs='+',
                        help='random seed', default=list(range(10)))
    parser.add_argument('--bin_size', type=int,
                        help='bin size for averaging', default=30)
    parser.add_argument('--prefix', type=str,
                        help='prefix to add to the save directory', default='')
    parser.add_argument('--suffix', type=str,
                        help='suffix to add to the save directory', default='')
    parser.add_argument('--ckpt', type=int,
                        help='checkpoint to load', default=None)
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print training details')
    return parser


def build_model_args(args: Namespace) -> dict:
    """Select the model keyword arguments matching ``args.model``.

    Args:
        args: Parsed command-line namespace as produced by ``build_parser``.
            Only the attributes relevant to the selected model are read.

    Returns:
        A dict of keyword arguments for the selected model. Every variant
        starts with ``d_feat`` and ``d_out`` (fixed to 1), followed by the
        model-specific entries.

    Raises:
        ValueError: If ``args.model`` is not one of the supported model keys.
    """
    # Entries shared by every model variant; kept first to preserve key order.
    shared = dict(d_feat=args.dim_feature, d_out=1)
    if args.model == 'ffnn':
        return dict(shared,
                    dim_hidden=args.ffnn_dim_hidden,
                    n_hidden_layers=args.ffnn_layers)
    if args.model == 'context_ffnn':
        return dict(shared,
                    dim_hidden=args.ffnn_dim_hidden,
                    n_hidden_layers=args.ffnn_layers,
                    n_ctxt=args.ctxt_len)
    if args.model == 'transformer':
        return dict(shared,
                    n_ctxt=args.ctxt_len,
                    n_layers=args.tf_layers,
                    mode=args.tf_mode,
                    activation=args.tf_activation,
                    apply_mask=args.apply_mask)
    if args.model == 'rnn':
        return dict(shared,
                    n_layers=args.rnn_layers)
    # The CLI restricts --model via `choices`; fail loudly for any other
    # caller instead of leaving the model arguments undefined.
    raise ValueError(f'unknown model: {args.model!r}')


def main(argv=None) -> None:
    """Parse the CLI, set up logging, and launch one training job per seed.

    Args:
        argv: Optional list of argument strings; when ``None`` the arguments
            are taken from the process command line.
    """
    args: Namespace = build_parser().parse_args(argv)

    # make_log_dir's result is indexed below by 'root' and by 'ckpt' -> seed.
    dir_dict = make_log_dir(args.seed, 'TD',
                            prefix=args.prefix, suffix=args.suffix)
    save_config(dir_dict['root'], vars(args))

    if args.verbose:
        for k, v in vars(args).items():
            print(f'{k}: {v}')

    # Built once; the same dict object is handed to every seed's job.
    model_args = build_model_args(args)

    works = []
    for seed in args.seed:
        ckpt = None
        if args.ckpt is not None:
            ckpt_path = os.path.join(dir_dict['ckpt'][seed],
                                     f'ckpt_{args.ckpt}.pt')
            # SECURITY NOTE: weights_only=False lets torch.load unpickle
            # arbitrary objects, which can execute code. Only point --ckpt
            # at checkpoints produced by trusted runs.
            ckpt = torch.load(ckpt_path, weights_only=False)
        works.append(delayed(online_learn)(log_path=dir_dict['root'],
                                           seed=seed,
                                           model_key=args.model,
                                           model_args=model_args,
                                           ctxt_len=args.ctxt_len,
                                           lr=args.lr,
                                           weight_decay=args.weight_decay,
                                           benchmark=args.benchmark,
                                           ckpt=ckpt))
    # n_jobs=-1: let joblib use all available CPUs for the per-seed jobs.
    Parallel(n_jobs=-1)(works)


if __name__ == '__main__':
    # Not referenced directly in this script; kept at guard level exactly as
    # before in case it is needed for import side effects or for resolving
    # objects inside pickled checkpoints -- TODO confirm before removing.
    from rl.td_task import TDTask  # noqa: F401

    main()