1010from ding .envs .env .base_env import BaseEnvTimestep
1111from ding .utils .registry_factory import ENV_REGISTRY
1212from gymnasium import spaces
13- from pettingzoo .classic .chess import chess_utils
14-
1513from zoo .board_games .chess .envs .chess_env import ChessEnv
14+ from pettingzoo .classic .chess import chess_utils as pz_cu
1615
1716
1817@ENV_REGISTRY .register ('chess_lightzero' )
@@ -50,16 +49,15 @@ def __init__(self, cfg=None):
5049
5150 @property
5251 def legal_actions (self ):
53- return chess_utils .legal_moves (self .board )
52+ return pz_cu .legal_moves (self .board )
5453
5554 def observe (self , agent_index ):
5655 try :
57- observation = chess_utils .get_observation (self .board , agent_index ).astype (float ) # TODO
56+ observation = pz_cu .get_observation (self .board , agent_index ).astype (float ) # TODO
5857 except Exception as e :
59- print ('debug' )
58+ print (f 'debug: { e } ' )
6059 print (f"self.board:{ self .board } " )
6160
62-
6361 # TODO:
6462 # observation = np.dstack((observation[:, :, :7], self.board_history))
6563 # We need to swap the white 6 channels with black 6 channels
@@ -75,9 +73,12 @@ def observe(self, agent_index):
7573 # observation[..., 13 * i : 13 * i + 6] = tmp
7674
7775 action_mask = np .zeros (4672 , dtype = np .int8 )
78- action_mask [chess_utils .legal_moves (self .board )] = 1
76+ action_mask [pz_cu .legal_moves (self .board )] = 1
7977 return {'observation' : observation , 'action_mask' : action_mask }
8078
79+
80+
81+
8182 def current_state (self ):
8283 """
8384 Overview:
@@ -103,7 +104,7 @@ def get_done_winner(self):
103104 if result == "*" :
104105 winner = - 1
105106 else :
106- winner = chess_utils .result_to_int (result )
107+ winner = pz_cu .result_to_int (result )
107108
108109 if not done :
109110 winner = - 1
@@ -143,7 +144,7 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
143144 self .board = chess .Board ()
144145
145146 action_mask = np .zeros (4672 , dtype = np .int8 )
146- action_mask [chess_utils .legal_moves (self .board )] = 1
147+ action_mask [pz_cu .legal_moves (self .board )] = 1
147148 # self.board_history = np.zeros((8, 8, 104), dtype=bool)
148149
149150 if self .battle_mode == 'play_with_bot_mode' or self .battle_mode == 'eval_mode' :
@@ -265,10 +266,10 @@ def _player_step(self, action):
265266 current_agent = self .current_player_index
266267
267268 # TODO: Update board history
268- # next_board = chess_utils .get_observation(self.board, current_agent)
269+ # next_board = pz_cu .get_observation(self.board, current_agent)
269270 # self.board_history = np.dstack((next_board[:, :, 7:], self.board_history[:, :, :-13]))
270271
271- chosen_move = chess_utils .action_to_move (self .board , action , current_agent )
272+ chosen_move = pz_cu .action_to_move (self .board , action , current_agent )
272273 assert chosen_move in self .board .legal_moves
273274 self .board .push (chosen_move )
274275
@@ -277,7 +278,7 @@ def _player_step(self, action):
277278 if result == "*" :
278279 reward = 0.
279280 else :
280- reward = chess_utils .result_to_int (result )
281+ reward = pz_cu .result_to_int (result )
281282
282283 if self .current_player == 1 :
283284 reward = - reward
@@ -287,7 +288,7 @@ def _player_step(self, action):
287288 info ['eval_episode_return' ] = reward
288289
289290 action_mask = np .zeros (4672 , dtype = np .int8 )
290- action_mask [chess_utils .legal_moves (self .board )] = 1
291+ action_mask [pz_cu .legal_moves (self .board )] = 1
291292
292293 obs = {
293294 'observation' : self .observe (self .current_player_index )['observation' ],
@@ -318,14 +319,14 @@ def current_player(self, value):
318319 self ._current_player = value
319320
320321 def random_action (self ):
321- action_list = chess_utils .legal_moves (self .board )
322+ action_list = pz_cu .legal_moves (self .board )
322323 return np .random .choice (action_list )
323324
324325 def simulate_action (self , action ):
325- if action not in chess_utils .legal_moves (self .board ):
326+ if action not in pz_cu .legal_moves (self .board ):
326327 raise ValueError ("action {0} on board {1} is not legal" .format (action , self .board .fen ()))
327328 new_board = copy .deepcopy (self .board )
328- new_board .push (chess_utils .action_to_move (self .board , action , self .current_player_index ))
329+ new_board .push (pz_cu .action_to_move (self .board , action , self .current_player_index ))
329330 if self .start_player_index == 0 :
330331 start_player_index = 1
331332 else :
0 commit comments