# run_environment.py
import os
import sys
import time
import keyboard
import argparse
import numpy as np
from stable_baselines3 import TD3, SAC
from stable_baselines3.common.env_checker import check_env
from inchworm import InchwormEnv


def train_with_sb3_agent(
    model_name="inchworm_td3",
    algorithm="td3",
    total_timesteps=30000,
    learning_rate=0.0003,
    render=False,
):
    """
    Trains the provided saved agent (or initializes a new agent if the provided
    model name doesn't exist) within the Inchworm environment for
    `total_timesteps` time steps. Saves the trained model to the `test_models/`
    directory once it finishes, or if it receives a KeyboardInterrupt.

    Parameters
    ----------
    - `model_name`: name of the zip file (minus the .zip extension) in `test_models/` that represents a saved pretrained agent
    - `algorithm`: the algorithm to use for training. Must be one of "td3" or "sac"
    - `total_timesteps`: the number of time steps to train the agent for
    - `learning_rate`: the learning rate to apply to the training
    - `render`: whether to render the simulation after the agent is done training
    """
    model_path = f"test_models/{model_name}.zip"

    algorithm_class = {"td3": TD3, "sac": SAC}.get(algorithm.lower())
    assert algorithm_class is not None, f"Invalid algorithm: {algorithm}"

    env = InchwormEnv(render_mode=("human" if render else "rgb_array"))

    # Make sure our env is compatible with the interface that
    # stable-baselines3 agents expect
    check_env(env)

    try:
        model = algorithm_class.load(model_path, env)
        print("Continuing training of saved model")
    except FileNotFoundError:
        print("No saved model found, training new model")
        model = algorithm_class("MlpPolicy", env, verbose=1, learning_rate=learning_rate)

    model.set_random_seed(time.time_ns() % 2**32)  # Set random seed to current time

    try:
        model.learn(total_timesteps, progress_bar=True)
    except KeyboardInterrupt:
        print("Interrupted by user, saving model")
    finally:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        model.save(model_path)

    if render:
        input("Done. Press enter to view trained agent")
        run_simulation_with_sb3_agent(
            model_name=model_name, model_dir="test_models", algorithm=algorithm
        )
    else:
        print("Done.")

    env.close()
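

# A minimal usage sketch (hypothetical model name): continue training the
# saved model test_models/inchworm_sac.zip with SAC for a short run, then
# render the result. A new model is created if no such file exists.
#
#   train_with_sb3_agent(
#       model_name="inchworm_sac",
#       algorithm="sac",
#       total_timesteps=10_000,
#       render=True,
#   )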


def run_simulation_with_sb3_agent(
    model_name="inchworm_td3",
    model_dir="saved_models",
    algorithm="td3",
    old_model=False,
    evals=False,
):
    """
    Runs the Inchworm environment using a provided saved agent, and applies the
    agent's actions to the environment without having the agent learn. For
    demonstration/testing purposes.

    Parameters
    ----------
    - `model_name`: name of the zip file (minus the .zip extension) contained in `model_dir` that represents a saved pretrained agent
    - `model_dir`: directory path where the model is stored (with no trailing slash)
    - `algorithm`: the algorithm the model was trained with. Must be one of "td3" or "sac"
    - `old_model`: whether the model was trained with the old version of the Inchworm environment
    - `evals`: whether to calculate evaluation metrics while running the agent
    """
    saved_model_path = f"{model_dir}/{model_name}.zip"

    algorithm_class = {"td3": TD3, "sac": SAC}.get(algorithm.lower())
    assert algorithm_class is not None, f"Invalid algorithm: {algorithm}"

    env = InchwormEnv(render_mode="human", old_model=old_model, evals=evals)

    # Make sure our env is compatible with the interface that
    # stable-baselines3 agents expect
    check_env(env)

    try:
        model = algorithm_class.load(saved_model_path, env)
        print("Using specified model")
    except FileNotFoundError:
        print("Specified model not found")
        sys.exit(1)

    model.set_random_seed(time.time_ns() % 2**32)  # Set random seed to current time

    vec_env = model.get_env()
    assert vec_env is not None

    obs = vec_env.reset()
    while True:
        try:
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, info = vec_env.step(action)
            vec_env.render("human")
        except KeyboardInterrupt:
            if evals:
                InchwormEnv.print_evals(info[0]["evals"], "Session Evaluation")
            break
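

# Usage sketch (hypothetical model name): replay a pretrained TD3 agent from
# saved_models/ and print session evaluation metrics on Ctrl+C.
#
#   run_simulation_with_sb3_agent(
#       model_name="inchworm_td3_v1",
#       model_dir="saved_models",
#       evals=True,
#   )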


def run_simulation_random():
    """
    Runs the Inchworm environment while providing random actions from the
    action space at each time step.
    """
    env = InchwormEnv(render_mode="human")

    # Must reset the env before making the first call to step()
    observation, info = env.reset()

    for _ in range(1000):
        try:
            # Select a random action from the action space
            action = env.action_space.sample()

            # Apply that action to the environment, store the resulting data
            observation, reward, terminated, truncated, info = env.step(action)

            # End current episode if necessary
            if terminated or truncated:
                observation, info = env.reset()
        except KeyboardInterrupt:
            break
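

# Note: env.step() here follows the Gymnasium 5-tuple API
# (observation, reward, terminated, truncated, info), whereas the SB3 VecEnv
# used in run_simulation_with_sb3_agent() returns the older 4-tuple
# (obs, reward, done, info) with batched values.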


def run_simulation_control():
    """
    Runs the Inchworm environment while allowing the user to control the
    inchworm themselves via their keyboard.

    NOTE: On Unix-like machines, this Python script must be run with `sudo`
    in order for the key press detection library to function.

    Controls
    --------
    - 'u'/'j': rotate the left joint clockwise and counterclockwise
    - 'i'/'k': rotate the middle joint clockwise and counterclockwise
    - 'o'/'l': rotate the right joint clockwise and counterclockwise
    - '[': enable the left adhesion gripper
    - ']': enable the right adhesion gripper
    - 'q': quit the simulation
    """
    env = InchwormEnv(render_mode="human")

    # Must reset the env before making the first call to step()
    observation, info = env.reset()

    while True:
        try:
            # Determine action from the currently pressed keys
            action = get_action()

            # Break on 'q' press
            if action is None:
                break

            # Apply that action to the environment, store the resulting data
            observation, reward, terminated, truncated, info = env.step(action)

            # End current episode if necessary
            if terminated or truncated:
                observation, info = env.reset()
        except KeyboardInterrupt:
            break


def get_action():
    """
    Polls the keyboard and builds the 5-element action vector: three joint
    commands in {-1, 0, 1}, followed by the two adhesion gripper toggles in
    {-1, 1}. Returns None on 'q' press to signal a quit.
    """
    if keyboard.is_pressed("q"):
        return None

    action = []

    # Left joint
    if keyboard.is_pressed("j"):
        action.append(1)
    elif keyboard.is_pressed("u"):
        action.append(-1)
    else:
        action.append(0)

    # Middle joint
    if keyboard.is_pressed("i"):
        action.append(1)
    elif keyboard.is_pressed("k"):
        action.append(-1)
    else:
        action.append(0)

    # Right joint
    if keyboard.is_pressed("l"):
        action.append(1)
    elif keyboard.is_pressed("o"):
        action.append(-1)
    else:
        action.append(0)

    # Adhesion grippers: 1 enables, -1 disables
    action.append(1 if keyboard.is_pressed("[") else -1)
    action.append(1 if keyboard.is_pressed("]") else -1)

    return np.array(action)
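

# For example, holding only 'i' and '[' yields np.array([0, 1, 0, 1, -1]):
# no left/right joint rotation, middle joint positive, left gripper enabled,
# right gripper disabled.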


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run or train an agent to control an inchworm robot"
    )

    group1 = parser.add_argument_group("Functional arguments (mutually exclusive)")
    group1e = group1.add_mutually_exclusive_group(required=True)
    group1e.add_argument(
        "-t", "--train",
        action="store_true",
        help="train a new/existing model in test_models/",
    )
    group1e.add_argument(
        "-r", "--run",
        action="store_true",
        help="run a saved model",
    )
    group1e.add_argument(
        "-R", "--random",
        action="store_true",
        help="run the environment with random actions",
    )
    group1e.add_argument(
        "-c", "--control",
        action="store_true",
        help="run the environment with user control",
    )

    group2 = parser.add_argument_group("Training and running arguments")
    group2.add_argument(
        "-m", "--model-name",
        type=str,
        help="name of the model to train/run (minus the .zip extension)",
    )
    group2.add_argument(
        "-a", "--algorithm",
        type=str,
        default="td3",
        help="algorithm to use for training/running the model, either sac or td3 (default: td3)",
    )

    group3 = parser.add_argument_group("Running arguments")
    group3.add_argument(
        "-s", "--saved-dir",
        action="store_true",
        help="whether the model is in the saved_models/ directory (otherwise test_models/)",
    )
    group3.add_argument(
        "-e", "--eval",
        action="store_true",
        help="whether to print out evaluation data while running the simulation",
    )
    group3.add_argument(
        "-o", "--old-model",
        action="store_true",
        help="whether the model was trained with the old version of the Inchworm environment",
    )

    group4 = parser.add_argument_group("Training arguments")
    group4.add_argument(
        "-T", "--total-timesteps",
        type=int,
        default=1_000_000,
        help="total number of timesteps to train the model for (default: 1,000,000)",
    )
    group4.add_argument(
        "-l", "--learning-rate",
        type=float,
        default=0.0003,
        help="learning rate for training the model (default: 0.0003)",
    )

    args = parser.parse_args()

    if args.train:
        if args.model_name is None:
            parser.error("argument -t/--train requires -m/--model-name")
        if args.saved_dir:
            parser.error(
                "argument -t/--train cannot be used with -s/--saved-dir "
                "(cannot train a model in the saved_models/ directory)"
            )
        train_with_sb3_agent(
            model_name=args.model_name,
            algorithm=args.algorithm,
            total_timesteps=args.total_timesteps,
            learning_rate=args.learning_rate,
        )
    elif args.run:
        if args.model_name is None:
            parser.error("argument -r/--run requires -m/--model-name")
        run_simulation_with_sb3_agent(
            model_name=args.model_name,
            algorithm=args.algorithm,
            model_dir="saved_models" if args.saved_dir else "test_models",
            old_model=args.old_model,
            evals=args.eval,
        )
    elif args.random:
        run_simulation_random()
    elif args.control:
        run_simulation_control()
    else:
        parser.print_help()
        sys.exit(0)
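
# Example invocations (model names are hypothetical; on Unix-like systems,
# -c requires sudo for the keyboard library):
#
#   python run_environment.py -t -m my_inchworm -a sac -T 100000
#   python run_environment.py -r -m inchworm_td3 -s -e
#   python run_environment.py -R
#   sudo python run_environment.py -c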