Skip to content

Commit 06b34f9

Browse files
authored
Merge pull request #44 from Agony5757/develop
Fix CI and migrate from gym to gymnasium
2 parents dce8a97 + 0fb5fbf commit 06b34f9

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

pymahjong/env_pymahjong.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import gym
1+
import gymnasium as gym
22
import numpy as np
33
import warnings
4-
from gym.spaces import Discrete, Box
4+
from gymnasium.spaces import Discrete, Box
55
import MahjongPyWrapper as pm
66

77
np.set_printoptions(threshold=np.inf)
@@ -431,15 +431,16 @@ def _proceed_until_agent_turn(self):
431431
action = self.opponent_agent.select(obs, action_mask=action_mask, greedy=True)
432432
self.env.step(self.env.get_curr_player_id(), action)
433433

434-
def reset(self, oya=None, game_wind=None, seed=None):
434+
def reset(self, *, oya=None, game_wind=None, seed=None, options=None):
435+
super().reset(seed=seed, options=options)
435436
self.env.reset(oya=oya, game_wind=game_wind, seed=seed)
436437
self._proceed_until_agent_turn()
437438

438439
if self.env.is_over():
439440
# if espisode length == 0 for the current player, ignore this game and re-start a new game
440441
return self.reset()
441442
else:
442-
return self.get_obs()
443+
return self.get_obs(), {}
443444

444445
def step(self, action):
445446
assert self.env.get_curr_player_id() == self.THIS_AGENT_ID
@@ -449,12 +450,12 @@ def step(self, action):
449450

450451
if self.env.is_over():
451452
r = self.env.get_payoffs()[self.THIS_AGENT_ID]
452-
done = True
453+
terminated = True
453454
else:
454455
r = 0
455-
done = False
456+
terminated = False
456457

457-
return self.env.get_obs(self.THIS_AGENT_ID), r, done, {}
458+
return self.env.get_obs(self.THIS_AGENT_ID), r, terminated, False, {}
458459

459460
def get_obs(self):
460461
return self.env.get_obs(self.THIS_AGENT_ID)

pymahjong/test/env_mahjong.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
import numpy as np
66
from copy import deepcopy
7-
import gym
7+
import gymnasium as gym
88
import pymahjong as mp
99

1010
from mahjong.shanten import Shanten
1111
from mahjong.tile import TilesConverter
1212

1313
shanten = Shanten()
1414

15-
from gym.spaces import Discrete, Box
15+
from gymnasium.spaces import Discrete, Box
1616

1717

1818
# ------------- OBS INDICES -----------

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def build_extension(self, ext):
159159
},
160160
classifiers=[_f for _f in CLASSIFIERS.split('\n') if _f],
161161
packages = ['pymahjong'],
162-
install_requires=['numpy', 'gym<=0.26.2'],
162+
install_requires=['numpy', 'gymnasium'],
163163
zip_safe = False,
164164
python_requires='>=3.8',
165165
)

0 commit comments

Comments
 (0)