|
18 | 18 | "discrete": False,
|
19 | 19 | }
|
20 | 20 |
|
| 21 | +ADDITIONAL_PO_ENV_PARAMS = { |
| 22 | + # num of vehicles the agent can observe on each incoming edge |
| 23 | + "num_observed": 2, |
| 24 | + # velocity to use in reward functions |
| 25 | + "target_velocity": 30, |
| 26 | +} |
| 27 | + |
21 | 28 |
|
22 | 29 | class TrafficLightGridEnv(Env):
|
23 | 30 | """Environment used to train traffic lights to regulate traffic flow
|
@@ -57,6 +64,12 @@ class TrafficLightGridEnv(Env):
|
57 | 64 | """
|
58 | 65 |
|
59 | 66 | def __init__(self, env_params, sumo_params, scenario):
|
| 67 | + |
| 68 | + for p in ADDITIONAL_ENV_PARAMS.keys(): |
| 69 | + if p not in env_params.additional_params: |
| 70 | + raise KeyError( |
| 71 | + 'Environment parameter "{}" not supplied'.format(p)) |
| 72 | + |
60 | 73 | self.grid_array = scenario.net_params.additional_params["grid_array"]
|
61 | 74 | self.rows = self.grid_array["row_num"]
|
62 | 75 | self.cols = self.grid_array["col_num"]
|
@@ -458,15 +471,15 @@ class PO_TrafficLightGridEnv(TrafficLightGridEnv):
|
458 | 471 | def __init__(self, env_params, sumo_params, scenario):
|
459 | 472 | super().__init__(env_params, sumo_params, scenario)
|
460 | 473 |
|
| 474 | + for p in ADDITIONAL_PO_ENV_PARAMS.keys(): |
| 475 | + if p not in env_params.additional_params: |
| 476 | + raise KeyError( |
| 477 | + 'Environment parameter "{}" not supplied'.format(p)) |
| 478 | + |
461 | 479 | # number of vehicles nearest each intersection that is observed in the
|
462 | 480 | # state space; defaults to 2
|
463 | 481 | self.num_observed = env_params.additional_params.get("num_observed", 2)
|
464 | 482 |
|
465 |
| - # used while computing the reward |
466 |
| - self.env_params.additional_params["target_velocity"] = \ |
467 |
| - max(self.scenario.speed_limit(edge) |
468 |
| - for edge in self.scenario.get_edge_list()) |
469 |
| - |
470 | 483 | # used during visualization
|
471 | 484 | self.observed_ids = []
|
472 | 485 |
|
|
0 commit comments