Skip to content

Commit 46ad1f5

Browse files
authored
Merge pull request berkeleyflow#45 from berkeleyflow/stress_test
adding stress tests used to diagnose race conditions
2 parents a351c4d + 41d5c85 commit 46ad1f5

File tree

4 files changed

+103
-74
lines changed

4 files changed

+103
-74
lines changed

examples/rllib/stress_test2.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

flow/utils/rllib.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212
from flow.core.traffic_lights import TrafficLights
1313
from flow.core.vehicles import Vehicles
1414

15-
import sumolib
16-
import time
17-
from copy import deepcopy
18-
1915

2016
def make_create_env(params, version=0, sumo_binary=None):
2117
"""Creates a parametrized flow environment compatible with RLlib.
@@ -75,6 +71,7 @@ def make_create_env(params, version=0, sumo_binary=None):
7571
module = __import__("flow.scenarios", fromlist=[params["generator"]])
7672
generator_class = getattr(module, params["generator"])
7773

74+
sumo_params = params['sumo']
7875
env_params = params['env']
7976
net_params = params['net']
8077
vehicles = params['veh']

stress_tests/stress_test_rl.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""
2+
Repeatedly runs one step of an environment to test for possible race conditions
3+
"""
4+
5+
import argparse
6+
import json
7+
import time
8+
import ray
9+
from ray.tune import run_experiments
10+
from ray.tune.registry import register_env
11+
12+
from flow.utils.rllib import FlowParamsEncoder
13+
14+
# use this to specify the environment to run
15+
from benchmarks.lanedrop0 import flow_params, env_name, create_env
16+
17+
# number of rollouts per training iteration
18+
N_ROLLOUTS = 50
19+
# number of parallel workers
20+
PARALLEL_ROLLOUTS = 50
21+
22+
EXAMPLE_USAGE = """
23+
example usage:
24+
python ./stress_test_rl.py PPO
25+
26+
Here the arguments are:
27+
PPO - the name of the RL algorithm you want to use for the stress test
28+
"""
29+
30+
parser = argparse.ArgumentParser(
31+
formatter_class=argparse.RawDescriptionHelpFormatter,
32+
description="Parses algorithm to run",
33+
epilog=EXAMPLE_USAGE)
34+
35+
# required input parameters
36+
parser.add_argument("alg", type=str,
37+
help="RL algorithm")
38+
39+
if __name__ == "__main__":
40+
args = parser.parse_args()
41+
alg = args.alg.upper()
42+
43+
start = time.time()
44+
print("stress test starting")
45+
ray.init(redirect_output=False)
46+
flow_params["env"].horizon = 1
47+
horizon = flow_params["env"].horizon
48+
if alg == 'ARS':
49+
import ray.rllib.ars as ars
50+
config = ars.DEFAULT_CONFIG.copy()
51+
config["num_workers"] = PARALLEL_ROLLOUTS
52+
config["num_deltas"] = PARALLEL_ROLLOUTS
53+
config["deltas_used"] = PARALLEL_ROLLOUTS
54+
elif alg == 'PPO':
55+
import ray.rllib.ppo as ppo
56+
config = ppo.DEFAULT_CONFIG.copy()
57+
config["num_workers"] = PARALLEL_ROLLOUTS
58+
config["timesteps_per_batch"] = horizon * N_ROLLOUTS
59+
config["vf_loss_coeff"] = 1.0
60+
config["kl_target"] = 0.02
61+
config["use_gae"] = True
62+
config["horizon"] = 1
63+
config["clip_param"] = 0.2
64+
config["num_sgd_iter"] = 1
65+
config["min_steps_per_task"] = 1
66+
config["sgd_batchsize"] = horizon * N_ROLLOUTS
67+
elif alg == 'ES':
68+
import ray.rllib.es as es
69+
config = es.DEFAULT_CONFIG.copy()
70+
config["num_workers"] = PARALLEL_ROLLOUTS
71+
config["episodes_per_batch"] = PARALLEL_ROLLOUTS
72+
config["timesteps_per_batch"] = PARALLEL_ROLLOUTS
73+
74+
# save the flow params for replay
75+
flow_json = json.dumps(flow_params, cls=FlowParamsEncoder, sort_keys=True,
76+
indent=4)
77+
config['env_config']['flow_params'] = flow_json
78+
79+
# Register as rllib env
80+
register_env(env_name, create_env)
81+
82+
trials = run_experiments({
83+
"highway_stabilize": {
84+
"run": alg, # Pulled from command line args
85+
"env": env_name,
86+
"config": {
87+
**config
88+
},
89+
"max_failures": 999,
90+
"stop": {"training_iteration": 50000},
91+
"repeat": 1,
92+
"trial_resources": {
93+
"cpu": 1,
94+
"gpu": 0,
95+
"extra_cpu": PARALLEL_ROLLOUTS - 1,
96+
},
97+
},
98+
})
99+
100+
end = time.time()
101+
102+
print("Stress test took " + str(end-start))
File renamed without changes.

0 commit comments

Comments
 (0)