Skip to content

Commit 560b8b4

Browse files
kjang96eugenevinitsky
authored andcommitted
Addressing issue 145 (#160)
* Addressing issue 145 * Fixed setup script env_params to reflect new changes * Fixed pep8 issues * Added new green_wave_env requirements to rllib/green_wave.py * Added new env params to the benchmark experiments as well * Added env params to baseline as well. nose2 not passing for some figure_eight scenarios
1 parent 4eb9b80 commit 560b8b4

File tree

8 files changed

+48
-16
lines changed

8 files changed

+48
-16
lines changed

examples/rllab/green_wave.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,10 @@ def run_task(*_):
106106

107107
additional_env_params = {
108108
"target_velocity": 50,
109-
"num_steps": 500,
110-
"switch_time": 3.0
109+
"switch_time": 3.0,
110+
"num_observed": 2,
111+
"discrete": False,
112+
"tl_type": "controlled"
111113
}
112114
env_params = EnvParams(additional_params=additional_env_params)
113115

examples/rllib/green_wave.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,13 @@ def get_non_flow_params(enter_speed, additional_net_params):
9595
"rl_veh": rl_veh
9696
}
9797

98-
additional_env_params = {"target_velocity": 50, "switch_time": 3.0}
98+
additional_env_params = {
99+
"target_velocity": 50,
100+
"switch_time": 3.0,
101+
"num_observed": 2,
102+
"discrete": False,
103+
"tl_type": "controlled"
104+
}
99105

100106
additional_net_params = {
101107
"speed_limit": 35,

flow/benchmarks/baselines/grid0.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ class needed to run simulations
123123
evaluate=True, # Set to True to evaluate traffic metrics
124124
horizon=HORIZON,
125125
additional_params={
126-
"switch_time": 2.0,
126+
"target_velocity": 50,
127+
"switch_time": 2,
127128
"num_observed": 2,
128-
"tl_type": "actuated",
129+
"discrete": False,
130+
"tl_type": "controlled"
129131
},
130132
)
131133

flow/benchmarks/baselines/grid1.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ class needed to run simulations
123123
evaluate=True, # Set to True to evaluate traffic metrics
124124
horizon=HORIZON,
125125
additional_params={
126-
"switch_time": 2.0,
126+
"target_velocity": 50,
127+
"switch_time": 2,
127128
"num_observed": 2,
128-
"tl_type": "actuated",
129+
"discrete": False,
130+
"tl_type": "controlled"
129131
},
130132
)
131133

flow/benchmarks/grid0.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,11 @@
8787
env=EnvParams(
8888
horizon=HORIZON,
8989
additional_params={
90-
"switch_time": 2.0,
90+
"target_velocity": 50,
91+
"switch_time": 2,
9192
"num_observed": 2,
93+
"discrete": False,
94+
"tl_type": "controlled"
9295
},
9396
),
9497

flow/benchmarks/grid1.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,11 @@
8787
env=EnvParams(
8888
horizon=HORIZON,
8989
additional_params={
90-
"switch_time": 2.0,
90+
"target_velocity": 50,
91+
"switch_time": 2,
9192
"num_observed": 2,
93+
"discrete": False,
94+
"tl_type": "controlled"
9295
},
9396
),
9497

flow/envs/green_wave_env.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
"discrete": False,
1919
}
2020

21+
ADDITIONAL_PO_ENV_PARAMS = {
22+
# num of vehicles the agent can observe on each incoming edge
23+
"num_observed": 2,
24+
# velocity to use in reward functions
25+
"target_velocity": 30,
26+
}
27+
2128

2229
class TrafficLightGridEnv(Env):
2330
"""Environment used to train traffic lights to regulate traffic flow
@@ -57,6 +64,12 @@ class TrafficLightGridEnv(Env):
5764
"""
5865

5966
def __init__(self, env_params, sumo_params, scenario):
67+
68+
for p in ADDITIONAL_ENV_PARAMS.keys():
69+
if p not in env_params.additional_params:
70+
raise KeyError(
71+
'Environment parameter "{}" not supplied'.format(p))
72+
6073
self.grid_array = scenario.net_params.additional_params["grid_array"]
6174
self.rows = self.grid_array["row_num"]
6275
self.cols = self.grid_array["col_num"]
@@ -458,15 +471,15 @@ class PO_TrafficLightGridEnv(TrafficLightGridEnv):
458471
def __init__(self, env_params, sumo_params, scenario):
459472
super().__init__(env_params, sumo_params, scenario)
460473

474+
for p in ADDITIONAL_PO_ENV_PARAMS.keys():
475+
if p not in env_params.additional_params:
476+
raise KeyError(
477+
'Environment parameter "{}" not supplied'.format(p))
478+
461479
# number of vehicles nearest each intersection that is observed in the
462480
# state space; defaults to 2
463481
self.num_observed = env_params.additional_params.get("num_observed", 2)
464482

465-
# used while computing the reward
466-
self.env_params.additional_params["target_velocity"] = \
467-
max(self.scenario.speed_limit(edge)
468-
for edge in self.scenario.get_edge_list())
469-
470483
# used during visualization
471484
self.observed_ids = []
472485

tests/setup_scripts.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,9 @@ def grid_mxn_exp_setup(row_num=1,
358358
# set default env_params configuration
359359
additional_env_params = {
360360
"target_velocity": 50,
361-
"num_steps": 100,
362-
"switch_time": 3.0
361+
"switch_time": 3.0,
362+
"tl_type": "controlled",
363+
"discrete": False
363364
}
364365

365366
env_params = EnvParams(

0 commit comments

Comments
 (0)