@@ -121,34 +121,39 @@ class TaxiEnv(Env):
121
121
## Information
122
122
123
123
`step()` and `reset()` return a dict with the following keys:
124
- - p - transition proability for the state.
124
+ - p - transition probability for the state.
125
125
- action_mask - if actions will cause a transition to a new state.
126
126
127
- As taxi is not stochastic, the transition probability is always 1.0. Implementing
128
- a transitional probability in line with the Dietterich paper ('The fickle taxi task')
129
- is a TODO.
130
-
131
127
For some cases, taking an action will have no effect on the state of the episode.
132
128
In v0.25.0, ``info["action_mask"]`` contains a np.ndarray for each of the actions specifying
133
129
if the action will change the state.
134
130
135
131
To sample a modifying action, use ``action = env.action_space.sample(info["action_mask"])``
136
132
Or with a Q-value based algorithm ``action = np.argmax(q_values[obs, np.where(info["action_mask"] == 1)[0]])``.
137
133
138
-
139
134
## Arguments
140
135
141
136
```python
142
137
import gymnasium as gym
143
138
gym.make('Taxi-v3')
144
139
```
145
140
141
+ <a id="is_rainy"></a>`is_rainy=False`: If True the cab will move in intended direction with
142
+ probability of 80% else will move in either left or right of target direction with
143
+ equal probability of 10% in both directions.
144
+
145
+ <a id="fickle_passenger"></a>`fickle_passenger=False`: If True the passenger has a 30% chance of changing
146
+ destinations when the cab has moved one square away from the passenger's source location. Passenger fickleness
147
+ only happens on the first pickup and successful movement. If the passenger is dropped off at the source location
148
+ and picked up again, it is not triggered again.
149
+
146
150
## References
147
151
<a id="taxi_ref"></a>[1] T. G. Dietterich, “Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition,”
148
152
Journal of Artificial Intelligence Research, vol. 13, pp. 227–303, Nov. 2000, doi: 10.1613/jair.639.
149
153
150
154
## Version History
151
155
* v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
156
+ - In Gymnasium `1.2.0` the `is_rainy` and `fickle_passenger` arguments were added to align with Dietterich, 2000
152
157
* v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
153
158
* v1: Remove (3,2) from locs, add passidx<4 check
154
159
* v0: Initial version release
@@ -159,7 +164,125 @@ class TaxiEnv(Env):
159
164
"render_fps" : 4 ,
160
165
}
161
166
162
- def __init__ (self , render_mode : Optional [str ] = None ):
167
+ def _pickup (self , taxi_loc , pass_idx , reward ):
168
+ """Computes the new location and reward for pickup action."""
169
+ if pass_idx < 4 and taxi_loc == self .locs [pass_idx ]:
170
+ new_pass_idx = 4
171
+ new_reward = reward
172
+ else : # passenger not at location
173
+ new_pass_idx = pass_idx
174
+ new_reward = - 10
175
+
176
+ return new_pass_idx , new_reward
177
+
178
+ def _dropoff (self , taxi_loc , pass_idx , dest_idx , default_reward ):
179
+ """Computes the new location and reward for return dropoff action."""
180
+ if (taxi_loc == self .locs [dest_idx ]) and pass_idx == 4 :
181
+ new_pass_idx = dest_idx
182
+ new_terminated = True
183
+ new_reward = 20
184
+ elif (taxi_loc in self .locs ) and pass_idx == 4 :
185
+ new_pass_idx = self .locs .index (taxi_loc )
186
+ new_terminated = False
187
+ new_reward = default_reward
188
+ else : # dropoff at wrong location
189
+ new_pass_idx = pass_idx
190
+ new_terminated = False
191
+ new_reward = - 10
192
+
193
+ return new_pass_idx , new_reward , new_terminated
194
+
195
+ def _build_dry_transitions (self , row , col , pass_idx , dest_idx , action ):
196
+ """Computes the next action for a state (row, col, pass_idx, dest_idx) and action."""
197
+ state = self .encode (row , col , pass_idx , dest_idx )
198
+
199
+ taxi_loc = (row , col )
200
+ new_row , new_col , new_pass_idx = row , col , pass_idx
201
+ reward = - 1 # default reward when there is no pickup/dropoff
202
+ terminated = False
203
+
204
+ if action == 0 :
205
+ new_row = min (row + 1 , self .max_row )
206
+ elif action == 1 :
207
+ new_row = max (row - 1 , 0 )
208
+ if action == 2 and self .desc [1 + row , 2 * col + 2 ] == b":" :
209
+ new_col = min (col + 1 , self .max_col )
210
+ elif action == 3 and self .desc [1 + row , 2 * col ] == b":" :
211
+ new_col = max (col - 1 , 0 )
212
+ elif action == 4 : # pickup
213
+ new_pass_idx , reward = self ._pickup (taxi_loc , new_pass_idx , reward )
214
+ elif action == 5 : # dropoff
215
+ new_pass_idx , reward , terminated = self ._dropoff (
216
+ taxi_loc , new_pass_idx , dest_idx , reward
217
+ )
218
+
219
+ new_state = self .encode (new_row , new_col , new_pass_idx , dest_idx )
220
+ self .P [state ][action ].append ((1.0 , new_state , reward , terminated ))
221
+
222
+ def _calc_new_position (self , row , col , movement , offset = 0 ):
223
+ """Calculates the new position for a row and col to the movement."""
224
+ dr , dc = movement
225
+ new_row = max (0 , min (row + dr , self .max_row ))
226
+ new_col = max (0 , min (col + dc , self .max_col ))
227
+ if self .desc [1 + new_row , 2 * new_col + offset ] == b":" :
228
+ return new_row , new_col
229
+ else : # Default to current position if not traversable
230
+ return row , col
231
+
232
+ def _build_rainy_transitions (self , row , col , pass_idx , dest_idx , action ):
233
+ """Computes the next action for a state (row, col, pass_idx, dest_idx) and action for `is_rainy`."""
234
+ state = self .encode (row , col , pass_idx , dest_idx )
235
+
236
+ taxi_loc = left_pos = right_pos = (row , col )
237
+ new_row , new_col , new_pass_idx = row , col , pass_idx
238
+ reward = - 1 # default reward when there is no pickup/dropoff
239
+ terminated = False
240
+
241
+ moves = {
242
+ 0 : ((1 , 0 ), (0 , - 1 ), (0 , 1 )), # Down
243
+ 1 : ((- 1 , 0 ), (0 , - 1 ), (0 , 1 )), # Up
244
+ 2 : ((0 , 1 ), (1 , 0 ), (- 1 , 0 )), # Right
245
+ 3 : ((0 , - 1 ), (1 , 0 ), (- 1 , 0 )), # Left
246
+ }
247
+
248
+ # Check if movement is allowed
249
+ if (
250
+ action in {0 , 1 }
251
+ or (action == 2 and self .desc [1 + row , 2 * col + 2 ] == b":" )
252
+ or (action == 3 and self .desc [1 + row , 2 * col ] == b":" )
253
+ ):
254
+ dr , dc = moves [action ][0 ]
255
+ new_row = max (0 , min (row + dr , self .max_row ))
256
+ new_col = max (0 , min (col + dc , self .max_col ))
257
+
258
+ left_pos = self ._calc_new_position (row , col , moves [action ][1 ], offset = 2 )
259
+ right_pos = self ._calc_new_position (row , col , moves [action ][2 ])
260
+ elif action == 4 : # pickup
261
+ new_pass_idx , reward = self ._pickup (taxi_loc , new_pass_idx , reward )
262
+ elif action == 5 : # dropoff
263
+ new_pass_idx , reward , terminated = self ._dropoff (
264
+ taxi_loc , new_pass_idx , dest_idx , reward
265
+ )
266
+ intended_state = self .encode (new_row , new_col , new_pass_idx , dest_idx )
267
+
268
+ if action <= 3 :
269
+ left_state = self .encode (left_pos [0 ], left_pos [1 ], new_pass_idx , dest_idx )
270
+ right_state = self .encode (
271
+ right_pos [0 ], right_pos [1 ], new_pass_idx , dest_idx
272
+ )
273
+
274
+ self .P [state ][action ].append ((0.8 , intended_state , - 1 , terminated ))
275
+ self .P [state ][action ].append ((0.1 , left_state , - 1 , terminated ))
276
+ self .P [state ][action ].append ((0.1 , right_state , - 1 , terminated ))
277
+ else :
278
+ self .P [state ][action ].append ((1.0 , intended_state , reward , terminated ))
279
+
280
+ def __init__ (
281
+ self ,
282
+ render_mode : Optional [str ] = None ,
283
+ is_rainy : bool = False ,
284
+ fickle_passenger : bool = False ,
285
+ ):
163
286
self .desc = np .asarray (MAP , dtype = "c" )
164
287
165
288
self .locs = locs = [(0 , 0 ), (0 , 4 ), (4 , 0 ), (4 , 3 )]
@@ -168,14 +291,15 @@ def __init__(self, render_mode: Optional[str] = None):
168
291
num_states = 500
169
292
num_rows = 5
170
293
num_columns = 5
171
- max_row = num_rows - 1
172
- max_col = num_columns - 1
294
+ self . max_row = num_rows - 1
295
+ self . max_col = num_columns - 1
173
296
self .initial_state_distrib = np .zeros (num_states )
174
297
num_actions = 6
175
298
self .P = {
176
299
state : {action : [] for action in range (num_actions )}
177
300
for state in range (num_states )
178
301
}
302
+
179
303
for row in range (num_rows ):
180
304
for col in range (num_columns ):
181
305
for pass_idx in range (len (locs ) + 1 ): # +1 for being inside taxi
@@ -184,47 +308,29 @@ def __init__(self, render_mode: Optional[str] = None):
184
308
if pass_idx < 4 and pass_idx != dest_idx :
185
309
self .initial_state_distrib [state ] += 1
186
310
for action in range (num_actions ):
187
- # defaults
188
- new_row , new_col , new_pass_idx = row , col , pass_idx
189
- reward = (
190
- - 1
191
- ) # default reward when there is no pickup/dropoff
192
- terminated = False
193
- taxi_loc = (row , col )
194
-
195
- if action == 0 :
196
- new_row = min (row + 1 , max_row )
197
- elif action == 1 :
198
- new_row = max (row - 1 , 0 )
199
- if action == 2 and self .desc [1 + row , 2 * col + 2 ] == b":" :
200
- new_col = min (col + 1 , max_col )
201
- elif action == 3 and self .desc [1 + row , 2 * col ] == b":" :
202
- new_col = max (col - 1 , 0 )
203
- elif action == 4 : # pickup
204
- if pass_idx < 4 and taxi_loc == locs [pass_idx ]:
205
- new_pass_idx = 4
206
- else : # passenger not at location
207
- reward = - 10
208
- elif action == 5 : # dropoff
209
- if (taxi_loc == locs [dest_idx ]) and pass_idx == 4 :
210
- new_pass_idx = dest_idx
211
- terminated = True
212
- reward = 20
213
- elif (taxi_loc in locs ) and pass_idx == 4 :
214
- new_pass_idx = locs .index (taxi_loc )
215
- else : # dropoff at wrong location
216
- reward = - 10
217
- new_state = self .encode (
218
- new_row , new_col , new_pass_idx , dest_idx
219
- )
220
- self .P [state ][action ].append (
221
- (1.0 , new_state , reward , terminated )
222
- )
311
+ if is_rainy :
312
+ self ._build_rainy_transitions (
313
+ row ,
314
+ col ,
315
+ pass_idx ,
316
+ dest_idx ,
317
+ action ,
318
+ )
319
+ else :
320
+ self ._build_dry_transitions (
321
+ row ,
322
+ col ,
323
+ pass_idx ,
324
+ dest_idx ,
325
+ action ,
326
+ )
223
327
self .initial_state_distrib /= self .initial_state_distrib .sum ()
224
328
self .action_space = spaces .Discrete (num_actions )
225
329
self .observation_space = spaces .Discrete (num_states )
226
330
227
331
self .render_mode = render_mode
332
+ self .fickle_passenger = fickle_passenger
333
+ self .fickle_step = self .fickle_passenger and self .np_random .random () < 0.3
228
334
229
335
# pygame utils
230
336
self .window = None
@@ -289,9 +395,28 @@ def step(self, a):
289
395
transitions = self .P [self .s ][a ]
290
396
i = categorical_sample ([t [0 ] for t in transitions ], self .np_random )
291
397
p , s , r , t = transitions [i ]
292
- self .s = s
293
398
self .lastaction = a
294
399
400
+ shadow_row , shadow_col , shadow_pass_loc , shadow_dest_idx = self .decode (self .s )
401
+ taxi_row , taxi_col , pass_loc , _ = self .decode (s )
402
+
403
+ # If we are in the fickle step, the passenger has been in the vehicle for at least a step and this step the
404
+ # position changed
405
+ if (
406
+ self .fickle_passenger
407
+ and self .fickle_step
408
+ and shadow_pass_loc == 4
409
+ and (taxi_row != shadow_row or taxi_col != shadow_col )
410
+ ):
411
+ self .fickle_step = False
412
+ possible_destinations = [
413
+ i for i in range (len (self .locs )) if i != shadow_dest_idx
414
+ ]
415
+ dest_idx = self .np_random .choice (possible_destinations )
416
+ s = self .encode (taxi_row , taxi_col , pass_loc , dest_idx )
417
+
418
+ self .s = s
419
+
295
420
if self .render_mode == "human" :
296
421
self .render ()
297
422
# truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
@@ -306,6 +431,7 @@ def reset(
306
431
super ().reset (seed = seed )
307
432
self .s = categorical_sample (self .initial_state_distrib , self .np_random )
308
433
self .lastaction = None
434
+ self .fickle_step = self .fickle_passenger and self .np_random .random () < 0.3
309
435
self .taxi_orientation = 0
310
436
311
437
if self .render_mode == "human" :
0 commit comments