1- import gym
1+ import gymnasium as gym
22import numpy as np
33import warnings
4- from gym .spaces import Discrete , Box
4+ from gymnasium .spaces import Discrete , Box
55import MahjongPyWrapper as pm
66
77np .set_printoptions (threshold = np .inf )
@@ -431,15 +431,16 @@ def _proceed_until_agent_turn(self):
431431 action = self .opponent_agent .select (obs , action_mask = action_mask , greedy = True )
432432 self .env .step (self .env .get_curr_player_id (), action )
433433
434- def reset (self , oya = None , game_wind = None , seed = None ):
434+ def reset (self , * , oya = None , game_wind = None , seed = None , options = None ):
435+ super ().reset (seed = seed , options = options )
435436 self .env .reset (oya = oya , game_wind = game_wind , seed = seed )
436437 self ._proceed_until_agent_turn ()
437438
438439 if self .env .is_over ():
439440 # if episode length == 0 for the current player, ignore this game and re-start a new game
440441 return self .reset ()
441442 else :
442- return self .get_obs ()
443+ return self .get_obs (), {}
443444
444445 def step (self , action ):
445446 assert self .env .get_curr_player_id () == self .THIS_AGENT_ID
@@ -449,12 +450,12 @@ def step(self, action):
449450
450451 if self .env .is_over ():
451452 r = self .env .get_payoffs ()[self .THIS_AGENT_ID ]
452- done = True
453+ terminated = True
453454 else :
454455 r = 0
455- done = False
456+ terminated = False
456457
457- return self .env .get_obs (self .THIS_AGENT_ID ), r , done , {}
458+ return self .env .get_obs (self .THIS_AGENT_ID ), r , terminated , False , {}
458459
459460 def get_obs (self ):
460461 return self .env .get_obs (self .THIS_AGENT_ID )
0 commit comments