@@ -19,6 +19,12 @@ def train_dqn(
1919 """
2020 Train a DQN agent to play Quoridor.
2121
22+ Julian notes:
23+ - This is for now working for a trivial 3x3 board with no walls
24+ - It teaches the agent to use black (player 2) only, against a random agent
25+ Note that in a 3x3 board with no walls, black can always force a win
26+ - It's currently not assigning negative rewards for losing
27+
2228 Args:
2329 episodes: Number of episodes to train for
2430 batch_size: Size of batches to sample from replay buffer
@@ -32,11 +38,8 @@ def train_dqn(
3238 """
3339 game = env (board_size = board_size , max_walls = max_walls , step_rewards = step_rewards )
3440
35- # Calculate action space size
36- action_size = board_size ** 2 + ((board_size - 1 ) ** 2 ) * 2
37-
3841 # Create the DQN agent
39- dqn_agent = FlatDQNAgent (board_size , action_size )
42+ dqn_agent = FlatDQNAgent (board_size , epsilon_decay = 0.9999 )
4043
4144 # Create a random opponent
4245 random_agent = RandomAgent ()
@@ -50,7 +53,6 @@ def train_dqn(
5053 for episode in range (episodes ):
5154 game .reset ()
5255
53- # Reset episode-specific variables
5456 episode_reward = 0
5557 episode_losses = []
5658
@@ -59,11 +61,12 @@ def train_dqn(
5961 observation , reward , termination , truncation , _ = game .last ()
6062
6163 # If the game is over, break the loop
64+ # TODO: Assign negative reward to the DQN agent when the opponent wins
6265 if termination or truncation :
6366 break
6467
6568 # If it's the DQN agent's turn
66- if agent_name == "player_0 " :
69+ if agent_name == "player_1 " :
6770 # Get current state
6871 state = dqn_agent .preprocess_observation (observation )
6972
@@ -73,8 +76,17 @@ def train_dqn(
7376 # Execute action
7477 game .step (action )
7578
76- # Get new state, reward, etc.
77- next_observation , reward , termination , truncation , _ = game .last ()
79+ # Get the observation and rewards for THIS agent (not the opponent)
80+ # NOTE: If we used game.last() it would return the observation and rewards for the currently active agent
81+ # which, since we already did game.step(), is now the opponent
82+ next_observation = game .observe (agent_name )
83+
84+ # Make the reward much larger than 1, to make it stand out
85+ reward = game .rewards [agent_name ] * 1000
86+
87+ # See if the game is over
88+ # TODO: Understand what truncation is and whether either of these values is player-dependent
89+ _ , _ , termination , truncation , _ = game .last ()
7890
7991 # Add to episode reward
8092 episode_reward += reward
@@ -115,10 +127,16 @@ def train_dqn(
115127 # Update target network periodically
116128 if episode % update_target_every == 0 :
117129 dqn_agent .update_target_network ()
118- avg_reward = sum (total_rewards [- 100 :]) / min (100 , len (total_rewards )) if total_rewards else 0.0
119- avg_loss = sum (losses [- 100 :]) / min (100 , len (losses )) if losses else 0.0
130+ avg_reward = (
131+ sum (total_rewards [- 1 * update_target_every :]) / min (update_target_every , len (total_rewards ))
132+ if total_rewards
133+ else 0.0
134+ )
135+ avg_loss = (
136+ sum (losses [- 1 * update_target_every :]) / min (update_target_every , len (losses )) if losses else 0.0
137+ )
120138 print (
121- f"Episode { episode } /{ episodes } , Avg Reward: { avg_reward :.2f} , "
139+ f"Episode { episode + 1 } /{ episodes } , Avg Reward: { avg_reward :.2f} , "
122140 f"Avg Loss: { avg_loss :.4f} , Epsilon: { dqn_agent .epsilon :.4f} "
123141 )
124142
0 commit comments