
Commit 69471be

Add stochastic taxi (rainy+fickle) (#1315)
Co-authored-by: Mark Towers <[email protected]>
1 parent ddf5c6f commit 69471be

File tree

2 files changed: +249 -46 lines changed


Diff for: gymnasium/envs/toy_text/taxi.py

+172 -46
@@ -121,34 +121,39 @@ class TaxiEnv(Env):
     ## Information

     `step()` and `reset()` return a dict with the following keys:
-    - p - transition proability for the state.
+    - p - transition probability for the state.
     - action_mask - if actions will cause a transition to a new state.

-    As taxi is not stochastic, the transition probability is always 1.0. Implementing
-    a transitional probability in line with the Dietterich paper ('The fickle taxi task')
-    is a TODO.
-
     For some cases, taking an action will have no effect on the state of the episode.
     In v0.25.0, ``info["action_mask"]`` contains a np.ndarray for each of the actions specifying
     if the action will change the state.

     To sample a modifying action, use ``action = env.action_space.sample(info["action_mask"])``
     Or with a Q-value based algorithm ``action = np.argmax(q_values[obs, np.where(info["action_mask"] == 1)[0]])``.

-
     ## Arguments

     ```python
     import gymnasium as gym
     gym.make('Taxi-v3')
     ```

+    <a id="is_rainy"></a>`is_rainy=False`: If True, the cab moves in the intended direction with
+    probability 80%; otherwise it moves to the left or the right of the intended direction,
+    each with probability 10%.
+
+    <a id="fickle_passenger"></a>`fickle_passenger=False`: If True, the passenger has a 30% chance of changing
+    destinations when the cab has moved one square away from the passenger's source location. Fickleness is
+    only triggered on the first pickup and the first successful move afterwards; if the passenger is dropped off
+    at the source location and picked up again, it is not triggered again.
+
     ## References
     <a id="taxi_ref"></a>[1] T. G. Dietterich, “Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition,”
     Journal of Artificial Intelligence Research, vol. 13, pp. 227–303, Nov. 2000, doi: 10.1613/jair.639.

     ## Version History
     * v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
+        - In Gymnasium `1.2.0` the `is_rainy` and `fickle_passenger` arguments were added to align with Dietterich, 2000
     * v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
     * v1: Remove (3,2) from locs, add passidx<4 check
     * v0: Initial version release
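
A minimal usage sketch of the two new arguments documented in the hunk above, assuming a Gymnasium build that includes this commit (the registered `Taxi-v3` entry forwards extra keyword arguments to the constructor):

```python
import gymnasium as gym

# Sketch only: construct the stochastic taxi variants introduced in this commit.
env = gym.make("Taxi-v3", is_rainy=True, fickle_passenger=True)

obs, info = env.reset(seed=0)
# With is_rainy=True, a movement action goes in the intended direction with
# probability 0.8 and slips to either side with probability 0.1 each;
# info["prob"] reports the probability of the transition that was sampled.
obs, reward, terminated, truncated, info = env.step(0)
print(obs, reward, info["prob"])
env.close()
```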
@@ -159,7 +164,125 @@ class TaxiEnv(Env):
         "render_fps": 4,
     }

-    def __init__(self, render_mode: Optional[str] = None):
+    def _pickup(self, taxi_loc, pass_idx, reward):
+        """Computes the new location and reward for pickup action."""
+        if pass_idx < 4 and taxi_loc == self.locs[pass_idx]:
+            new_pass_idx = 4
+            new_reward = reward
+        else:  # passenger not at location
+            new_pass_idx = pass_idx
+            new_reward = -10
+
+        return new_pass_idx, new_reward
+
+    def _dropoff(self, taxi_loc, pass_idx, dest_idx, default_reward):
+        """Computes the new location and reward for return dropoff action."""
+        if (taxi_loc == self.locs[dest_idx]) and pass_idx == 4:
+            new_pass_idx = dest_idx
+            new_terminated = True
+            new_reward = 20
+        elif (taxi_loc in self.locs) and pass_idx == 4:
+            new_pass_idx = self.locs.index(taxi_loc)
+            new_terminated = False
+            new_reward = default_reward
+        else:  # dropoff at wrong location
+            new_pass_idx = pass_idx
+            new_terminated = False
+            new_reward = -10
+
+        return new_pass_idx, new_reward, new_terminated
+
+    def _build_dry_transitions(self, row, col, pass_idx, dest_idx, action):
+        """Computes the next action for a state (row, col, pass_idx, dest_idx) and action."""
+        state = self.encode(row, col, pass_idx, dest_idx)
+
+        taxi_loc = (row, col)
+        new_row, new_col, new_pass_idx = row, col, pass_idx
+        reward = -1  # default reward when there is no pickup/dropoff
+        terminated = False
+
+        if action == 0:
+            new_row = min(row + 1, self.max_row)
+        elif action == 1:
+            new_row = max(row - 1, 0)
+        if action == 2 and self.desc[1 + row, 2 * col + 2] == b":":
+            new_col = min(col + 1, self.max_col)
+        elif action == 3 and self.desc[1 + row, 2 * col] == b":":
+            new_col = max(col - 1, 0)
+        elif action == 4:  # pickup
+            new_pass_idx, reward = self._pickup(taxi_loc, new_pass_idx, reward)
+        elif action == 5:  # dropoff
+            new_pass_idx, reward, terminated = self._dropoff(
+                taxi_loc, new_pass_idx, dest_idx, reward
+            )
+
+        new_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
+        self.P[state][action].append((1.0, new_state, reward, terminated))
+
+    def _calc_new_position(self, row, col, movement, offset=0):
+        """Calculates the new position for a row and col to the movement."""
+        dr, dc = movement
+        new_row = max(0, min(row + dr, self.max_row))
+        new_col = max(0, min(col + dc, self.max_col))
+        if self.desc[1 + new_row, 2 * new_col + offset] == b":":
+            return new_row, new_col
+        else:  # Default to current position if not traversable
+            return row, col
+
+    def _build_rainy_transitions(self, row, col, pass_idx, dest_idx, action):
+        """Computes the next action for a state (row, col, pass_idx, dest_idx) and action for `is_rainy`."""
+        state = self.encode(row, col, pass_idx, dest_idx)
+
+        taxi_loc = left_pos = right_pos = (row, col)
+        new_row, new_col, new_pass_idx = row, col, pass_idx
+        reward = -1  # default reward when there is no pickup/dropoff
+        terminated = False
+
+        moves = {
+            0: ((1, 0), (0, -1), (0, 1)),  # Down
+            1: ((-1, 0), (0, -1), (0, 1)),  # Up
+            2: ((0, 1), (1, 0), (-1, 0)),  # Right
+            3: ((0, -1), (1, 0), (-1, 0)),  # Left
+        }
+
+        # Check if movement is allowed
+        if (
+            action in {0, 1}
+            or (action == 2 and self.desc[1 + row, 2 * col + 2] == b":")
+            or (action == 3 and self.desc[1 + row, 2 * col] == b":")
+        ):
+            dr, dc = moves[action][0]
+            new_row = max(0, min(row + dr, self.max_row))
+            new_col = max(0, min(col + dc, self.max_col))
+
+            left_pos = self._calc_new_position(row, col, moves[action][1], offset=2)
+            right_pos = self._calc_new_position(row, col, moves[action][2])
+        elif action == 4:  # pickup
+            new_pass_idx, reward = self._pickup(taxi_loc, new_pass_idx, reward)
+        elif action == 5:  # dropoff
+            new_pass_idx, reward, terminated = self._dropoff(
+                taxi_loc, new_pass_idx, dest_idx, reward
+            )
+        intended_state = self.encode(new_row, new_col, new_pass_idx, dest_idx)
+
+        if action <= 3:
+            left_state = self.encode(left_pos[0], left_pos[1], new_pass_idx, dest_idx)
+            right_state = self.encode(
+                right_pos[0], right_pos[1], new_pass_idx, dest_idx
+            )
+
+            self.P[state][action].append((0.8, intended_state, -1, terminated))
+            self.P[state][action].append((0.1, left_state, -1, terminated))
+            self.P[state][action].append((0.1, right_state, -1, terminated))
+        else:
+            self.P[state][action].append((1.0, intended_state, reward, terminated))
+
+    def __init__(
+        self,
+        render_mode: Optional[str] = None,
+        is_rainy: bool = False,
+        fickle_passenger: bool = False,
+    ):
         self.desc = np.asarray(MAP, dtype="c")

         self.locs = locs = [(0, 0), (0, 4), (4, 0), (4, 3)]
@@ -168,14 +291,15 @@ def __init__(self, render_mode: Optional[str] = None):
         num_states = 500
         num_rows = 5
         num_columns = 5
-        max_row = num_rows - 1
-        max_col = num_columns - 1
+        self.max_row = num_rows - 1
+        self.max_col = num_columns - 1
         self.initial_state_distrib = np.zeros(num_states)
         num_actions = 6
         self.P = {
             state: {action: [] for action in range(num_actions)}
             for state in range(num_states)
         }
+
         for row in range(num_rows):
             for col in range(num_columns):
                 for pass_idx in range(len(locs) + 1):  # +1 for being inside taxi
@@ -184,47 +308,29 @@ def __init__(self, render_mode: Optional[str] = None):
                         if pass_idx < 4 and pass_idx != dest_idx:
                             self.initial_state_distrib[state] += 1
                         for action in range(num_actions):
-                            # defaults
-                            new_row, new_col, new_pass_idx = row, col, pass_idx
-                            reward = (
-                                -1
-                            )  # default reward when there is no pickup/dropoff
-                            terminated = False
-                            taxi_loc = (row, col)
-
-                            if action == 0:
-                                new_row = min(row + 1, max_row)
-                            elif action == 1:
-                                new_row = max(row - 1, 0)
-                            if action == 2 and self.desc[1 + row, 2 * col + 2] == b":":
-                                new_col = min(col + 1, max_col)
-                            elif action == 3 and self.desc[1 + row, 2 * col] == b":":
-                                new_col = max(col - 1, 0)
-                            elif action == 4:  # pickup
-                                if pass_idx < 4 and taxi_loc == locs[pass_idx]:
-                                    new_pass_idx = 4
-                                else:  # passenger not at location
-                                    reward = -10
-                            elif action == 5:  # dropoff
-                                if (taxi_loc == locs[dest_idx]) and pass_idx == 4:
-                                    new_pass_idx = dest_idx
-                                    terminated = True
-                                    reward = 20
-                                elif (taxi_loc in locs) and pass_idx == 4:
-                                    new_pass_idx = locs.index(taxi_loc)
-                                else:  # dropoff at wrong location
-                                    reward = -10
-                            new_state = self.encode(
-                                new_row, new_col, new_pass_idx, dest_idx
-                            )
-                            self.P[state][action].append(
-                                (1.0, new_state, reward, terminated)
-                            )
+                            if is_rainy:
+                                self._build_rainy_transitions(
+                                    row,
+                                    col,
+                                    pass_idx,
+                                    dest_idx,
+                                    action,
+                                )
+                            else:
+                                self._build_dry_transitions(
+                                    row,
+                                    col,
+                                    pass_idx,
+                                    dest_idx,
+                                    action,
+                                )
         self.initial_state_distrib /= self.initial_state_distrib.sum()
         self.action_space = spaces.Discrete(num_actions)
         self.observation_space = spaces.Discrete(num_states)

         self.render_mode = render_mode
+        self.fickle_passenger = fickle_passenger
+        self.fickle_step = self.fickle_passenger and self.np_random.random() < 0.3

         # pygame utils
         self.window = None
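
To see what the `is_rainy` branch above writes into `self.P`, a small inspection sketch (not part of the commit; it instantiates `TaxiEnv` directly and picks an arbitrary state):

```python
from gymnasium.envs.toy_text.taxi import TaxiEnv

env = TaxiEnv(is_rainy=True)
# Arbitrary state: taxi at (2, 2), passenger waiting at locs[0], destination locs[1].
state = env.encode(2, 2, 0, 1)
for prob, next_state, reward, terminated in env.P[state][0]:  # action 0 = move south
    print(prob, list(env.decode(next_state)), reward, terminated)
# Expected: one 0.8 entry for the intended move and two 0.1 entries for the
# sideways slips; a slip blocked by a wall leaves the taxi where it is.
```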
@@ -289,9 +395,28 @@ def step(self, a):
         transitions = self.P[self.s][a]
         i = categorical_sample([t[0] for t in transitions], self.np_random)
         p, s, r, t = transitions[i]
-        self.s = s
         self.lastaction = a

+        shadow_row, shadow_col, shadow_pass_loc, shadow_dest_idx = self.decode(self.s)
+        taxi_row, taxi_col, pass_loc, _ = self.decode(s)
+
+        # If we are in the fickle step, the passenger has been in the vehicle for at least a step and this step the
+        # position changed
+        if (
+            self.fickle_passenger
+            and self.fickle_step
+            and shadow_pass_loc == 4
+            and (taxi_row != shadow_row or taxi_col != shadow_col)
+        ):
+            self.fickle_step = False
+            possible_destinations = [
+                i for i in range(len(self.locs)) if i != shadow_dest_idx
+            ]
+            dest_idx = self.np_random.choice(possible_destinations)
+            s = self.encode(taxi_row, taxi_col, pass_loc, dest_idx)
+
+        self.s = s
+
         if self.render_mode == "human":
             self.render()
         # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
@@ -306,6 +431,7 @@ def reset(
         super().reset(seed=seed)
         self.s = categorical_sample(self.initial_state_distrib, self.np_random)
         self.lastaction = None
+        self.fickle_step = self.fickle_passenger and self.np_random.random() < 0.3
        self.taxi_orientation = 0

         if self.render_mode == "human":
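
Because the 30% fickleness chance is drawn once per episode in `reset()` (the `self.fickle_step` flag above), it can be sanity-checked empirically without driving the taxi around. A rough sketch, not part of the test suite:

```python
from gymnasium.envs.toy_text.taxi import TaxiEnv

env = TaxiEnv(fickle_passenger=True)
n = 10_000
draws = 0
for seed in range(n):
    env.reset(seed=seed)
    draws += env.fickle_step  # re-drawn on every reset with probability ~0.3
print(draws / n)  # should land near 0.3
```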

Diff for: tests/envs/test_env_implementation.py

+77
@@ -209,6 +209,83 @@ def test_taxi_encode_decode():
         state, _, _, _, _ = env.step(env.action_space.sample())


+def test_taxi_is_rainy():
+    env = TaxiEnv(is_rainy=True)
+    for state_dict in env.P.values():
+        for action, transitions in state_dict.items():
+            if action <= 3:
+                assert sum([t[0] for t in transitions]) == 1
+                assert {t[0] for t in transitions} == {0.8, 0.1}
+            else:
+                assert len(transitions) == 1
+                assert transitions[0][0] == 1.0
+
+    state, _ = env.reset()
+    _, _, _, _, info = env.step(0)
+    assert info["prob"] in {0.8, 0.1}
+
+    env = TaxiEnv(is_rainy=False)
+    for state_dict in env.P.values():
+        for action, transitions in state_dict.items():
+            assert len(transitions) == 1
+            assert transitions[0][0] == 1.0
+
+    state, _ = env.reset()
+    _, _, _, _, info = env.step(0)
+    assert info["prob"] == 1.0
+
+
+def test_taxi_disallowed_transitions():
+    disallowed_transitions = [
+        ((0, 1), (0, 3)),
+        ((0, 3), (0, 1)),
+        ((1, 0), (1, 2)),
+        ((1, 2), (1, 0)),
+        ((3, 1), (3, 3)),
+        ((3, 3), (3, 1)),
+        ((3, 3), (3, 5)),
+        ((3, 5), (3, 3)),
+        ((4, 1), (4, 3)),
+        ((4, 3), (4, 1)),
+        ((4, 3), (4, 5)),
+        ((4, 5), (4, 3)),
+    ]
+    for rain in {True, False}:
+        env = TaxiEnv(is_rainy=rain)
+        for state, state_dict in env.P.items():
+            start_row, start_col, _, _ = env.decode(state)
+            for action, transitions in state_dict.items():
+                for transition in transitions:
+                    end_row, end_col, _, _ = env.decode(transition[1])
+                    assert (
+                        (start_row, start_col),
+                        (end_row, end_col),
+                    ) not in disallowed_transitions
+
+
+def test_taxi_fickle_passenger():
+    env = TaxiEnv(fickle_passenger=True)
+    # This is a fickle seed, if randomness or the draws from the PRNG were recently updated, find a new seed
+    env.reset(seed=43)
+    state, *_ = env.step(0)
+    taxi_row, taxi_col, pass_idx, orig_dest_idx = env.decode(state)
+    # force taxi to passenger location
+    env.s = env.encode(
+        env.locs[pass_idx][0], env.locs[pass_idx][1], pass_idx, orig_dest_idx
+    )
+    # pick up the passenger
+    env.step(4)
+    if env.locs[pass_idx][0] == 0:
+        # if we're on the top row, move down
+        state, *_ = env.step(0)
+    else:
+        # otherwise move up
+        state, *_ = env.step(1)
+    taxi_row, taxi_col, pass_idx, dest_idx = env.decode(state)
+    # check that passenger has changed their destination
+    assert orig_dest_idx != dest_idx
+
+
 @pytest.mark.parametrize(
     "env_name",
     ["Acrobot-v1", "CartPole-v1", "MountainCar-v0", "MountainCarContinuous-v0"],
