# main.py
"""Note main use mp.span hence it will fork to separate pids.
Mus
"""
import argparse
import os
import sys
from pathlib import Path
import glob
import torch
import torch.multiprocessing as mp
from meta_critics.app_globals import get_running_mode, SpecTypes, AppSelector
from meta_critics.running_spec import RunningSpecError, RunningSpec
from meta_critics.rpc.rpc_trainer import run_worker
def check_all_configs(cmd: argparse.Namespace) -> bool:
    """Check every ``*.yaml`` spec file in the config directory and report the result.

    The directory is derived from ``cmd.config``: if that path is a file, its
    parent directory is scanned, otherwise the path itself is scanned.  Each
    file is validated independently, so one broken spec does not prevent the
    remaining files from being checked.

    :param cmd: parsed command line arguments; ``cmd.config`` must be set.
    :return: True when no file raised ``RunningSpecError`` (also when no
             ``*.yaml`` file was found), False otherwise.
    """
    config_dir = Path(cmd.config).expanduser().resolve()
    if config_dir.is_file():
        config_dir = config_dir.parent

    # Escape the directory part so characters such as "[" in the path are not
    # interpreted as glob wildcards; sort for a deterministic report order.
    pattern = os.path.join(glob.escape(str(config_dir)), "*.yaml")
    spec_files = sorted(glob.glob(pattern))

    current_dir = os.getcwd()
    all_valid = True
    for spec_file in spec_files:
        # Validate *this* file rather than the file originally given on the
        # command line: hand RunningSpec a copy of the arguments whose
        # ``config`` points at the file being checked.
        # NOTE(review): assumes RunningSpec takes the spec path from
        # ``cmd.config`` -- confirm against meta_critics.running_spec.
        file_cmd = argparse.Namespace(**vars(cmd))
        file_cmd.config = spec_file
        try:
            spec = RunningSpec(file_cmd, AppSelector.TranModel, current_dir)
            spec.check_running_config()
        except RunningSpecError as rse:
            all_valid = False
            print(f"Please check {spec_file} it contains missing section Error:", rse)
        else:
            print(f"Checking file {spec_file}, no errors found.")

    return all_valid
def main(cmd, spec):
    """Spawn the worker processes and wait for all of them to finish.

    ``cmd.workers`` is used both as the number of processes to start and as
    the world size forwarded to each ``run_worker`` call.

    :param cmd: parsed command line arguments; only ``cmd.workers`` is read here.
    :param spec: running spec forwarded unchanged to every worker.
    :return: None
    """
    # A range from ``workers`` to ``workers + 1`` holds exactly one value,
    # so a single spawn call covers it.
    world_size = cmd.workers
    mp.spawn(run_worker, args=(world_size, spec), nprocs=world_size, join=True)
    print("All main done.")
if __name__ == '__main__':
    # Script entry point: parse the command line, pick the running mode,
    # build and validate the running spec, then hand over to main().
    parser = argparse.ArgumentParser(description="Main DH-MAML")

    # Mode selection flags.
    parser.add_argument('--tune', action='store_true', required=False,
                        help='run ray hyperparameter optimization.')
    parser.add_argument('--test', action='store_true', required=False,
                        help="test model on task")
    parser.add_argument('--plot', action='store_true', required=False,
                        help="plots test result")
    parser.add_argument('--train', action='store_true', required=False,
                        help="train model for task")
    parser.add_argument('--check_specs', action='store_true', required=False,
                        help="will check all spec files for errors.")
    parser.add_argument('--benchmark', action='store_true', required=False,
                        help="will measure time of execution for different number of threads")

    # General options.
    parser.add_argument('--use-cpu', action='store_true',
                        help='if we want enforce cpu only.')
    parser.add_argument('--config', type=str, required=True,
                        help="a path to the configuration json or yaml file.")
    parser.add_argument('--model_file', type=str, required=False, default="default.th",
                        help="a path to the a model file.")
    parser.add_argument('--is_verbose', action='store_true', required=False,
                        help="Enable verbose out during test")
    parser.add_argument('--human_render', action='store_true', required=False,
                        help="observer agent in action")

    trainer = parser.add_argument_group('trainer')
    trainer.add_argument('--num_batches', type=int, default=500,
                         help="number of batches. Default 500.")
    trainer.add_argument('--num_meta_test', type=int, default=10,
                         help="number of meta test batch perform. "
                              "i.e we perform num meta test iteration "
                              "for K task, Default 10.")
    trainer.add_argument('--num_meta_task', type=int, default=40,
                         help="number of meta tasks per batch Default: 40)")
    trainer.add_argument('--num_trajectory', type=int, default=20,
                         help="number of trajectory per task "
                              "per observer to collect. Default 20")
    trainer.add_argument('--save_freq', type=int, default=20,
                         help="how often checkpoint policy.")
    trainer.add_argument('--disable_meta_test', action='store_true',
                         required=False, help="disable meta testing.")

    # Register the option on the Meta-Test group itself so that it is listed
    # under that heading in --help instead of leaving the group empty.
    meta_test = parser.add_argument_group('Meta-Test')
    meta_test.add_argument('--meta_test_freq', type=int, default=100, required=False,
                           help="A frequency when we want do a meta test, during training."
                                "note during meta test, we load a completely new policy")

    misc = parser.add_argument_group('Miscellaneous')
    misc.add_argument('--config_type', type=SpecTypes, default=SpecTypes.JSON, help='config file type.')
    misc.add_argument('--model_dir', type=str, required=False, help='a directory where we will model data.')
    misc.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    misc.add_argument('--debug_agent', action='store_true', required=False, help='Enables debug for agent.')
    misc.add_argument('--debug_task_sampler', action='store_true', required=False, help='Enables debug environment.')
    misc.add_argument('--debug_env', action='store_true', required=False, help='Enables debug environment.')
    misc.add_argument('--gamma', type=float, default=1.0, metavar='G', help='discount factor (default: 1.0)')
    misc.add_argument('--workers', type=int, default=2, help='Number of workers minimum 2. Worker 1 main Agent.')
    misc.add_argument('--num_worker_threads', type=int, default=16, help='Number of workers threads.')
    misc.add_argument('--rpc_timeout', type=int, default=180, help='RPC timeout settings.')
    misc.add_argument('--rpc_port', type=str, default="29519", help='default rpc port.')
    misc.add_argument('--disable_wandb', action='store_true', required=False, help="disable wandb")

    args = parser.parse_args()

    # Use CUDA only when it is available and the user did not force CPU.
    use_cuda = torch.cuda.is_available() and not args.use_cpu
    args.device = 'cuda' if use_cuda else 'cpu'
    # torch.manual_seed(args.seed)

    mode = get_running_mode(args)
    if mode is None:
        print("Please select either train/test/plot/check_specs")
        sys.exit(1)

    if mode == AppSelector.CheckSpec:
        check_all_configs(args)
        # NOTE(review): the exit status is 1 regardless of the check outcome.
        sys.exit(1)

    try:
        current_dir = os.getcwd()
        running_spec = RunningSpec(args, mode, current_dir)
        running_spec.update_running_config(args)

        # List of settings we allow to overwrite from the command line, as
        # (key, value, root section) triples.
        # NOTE(review): these options all have argparse defaults, so update()
        # is called even when the flag was not given on the command line --
        # confirm that replacing the config file value is intended.
        cli_overrides = (
            ('num_batches', args.num_batches, 'meta_task'),
            ('save_freq', args.save_freq, 'trainer'),
            ('meta_test_freq', args.meta_test_freq, 'trainer'),
            ('num_meta_test', args.num_meta_test, 'meta_task'),
            ('num_meta_task', args.num_meta_task, 'meta_task'),
        )
        for key, value, root in cli_overrides:
            running_spec.update(key, value, root=root)

        running_spec.check_running_config()
    except RunningSpecError as r_except:
        print(f"Error in file {args.config}, error:", r_except)
        sys.exit(100)

    try:
        main(args, running_spec)
    except KeyboardInterrupt:
        print("Shutting please wait.")