forked from BigBorg/gym_athlete
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path manage.py
More file actions
31 lines (26 loc) · 1.19 KB
/
manage.py
File metadata and controls
31 lines (26 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
import argparse
from networks.cart_pole import CartPoleAthelete
from networks.acrobot import AcrobotAthlete
from utils.utils import interact_gym_env
# Map each supported Gym environment id to the athlete class that handles it.
# The CLI "environment" choices below are derived from this mapping so the
# two lists cannot drift apart.
environment_to_class = {
    "CartPole-v1": CartPoleAthelete,
    "Acrobot-v1": AcrobotAthlete,
}

# Command-line interface:
#   train    -> athlete.train(100, "saved_models/<environment>")
#   eval     -> athlete.estimate_model(model_path=--model_path, render=True)
#   interact -> interact_gym_env(<environment>)
# The first positional argument of ArgumentParser is `prog` (the program name
# shown in usage), so the title is passed as `description` instead.
argument_parser = argparse.ArgumentParser(description="Gym Athlete")
argument_parser.add_argument(
    "command",
    type=str,
    choices=["train", "eval", "interact"],
    help="action to perform",
)
argument_parser.add_argument(
    "environment",
    type=str,
    choices=list(environment_to_class),
    help="Gym environment id",
)
argument_parser.add_argument(
    "--model_path",
    type=str,
    required=False,
    help="path of the saved model to load (required for eval)",
)
parsed_arguments = argument_parser.parse_args()

# argparse always sets optional arguments (defaulting to None), so a
# hasattr() check can never detect a missing --model_path. Test for None
# instead, and do it before any TensorFlow/model setup so a bad invocation
# fails fast. error() prints usage plus the message and exits non-zero.
if parsed_arguments.command == "eval" and parsed_arguments.model_path is None:
    argument_parser.error("请指定模型路径")  # "Please specify the model path"

# Switch TensorFlow to graph (non-eager) execution before any model is built.
tf.compat.v1.disable_eager_execution()
Athlete = environment_to_class[parsed_arguments.environment]
athlete = Athlete(parsed_arguments.environment)
if parsed_arguments.command == "train":
    # NOTE(review): the meaning of 100 (episodes/iterations) is defined by
    # Athlete.train, which is not visible here — confirm.
    athlete.train(100, "saved_models/" + parsed_arguments.environment)
elif parsed_arguments.command == "eval":
    athlete.estimate_model(model_path=parsed_arguments.model_path, render=True)
elif parsed_arguments.command == "interact":
    interact_gym_env(parsed_arguments.environment)