Skip to content

Commit 1337d1a

Browse files
Merge pull request #117 from adamantivm/dq/dqn-training-fixes
First working DQN, wins at trivial game
2 parents 99517ae + 875f352 commit 1337d1a

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

deep_quoridor/src/agents/flat_dqn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import deque
77
import random
88
from agents import SelfRegisteringAgent
9+
from pathlib import Path
910

1011

1112
class DQNNetwork(nn.Module):
@@ -182,8 +183,8 @@ class Pretrained01FlatDQNAgent(FlatDQNAgent):
182183
"""
183184

184185
def __init__(self, board_size, **kwargs):
185-
super(Pretrained01FlatDQNAgent, self).__init__(board_size)
186-
model_path = "/home/julian/aaae/deep-rabbit-hole/code/deep_rabbit_hole/models/dqn_flat_nostep_final.pt"
186+
super().__init__(board_size, epsilon=0.0)
187+
model_path = Path(__file__).resolve().parents[3] / "models" / "dqn_agent_final.pt"
187188
if os.path.exists(model_path):
188189
print(f"Loading pre-trained model from {model_path}")
189190
self.load_model(model_path)

deep_quoridor/src/train_dqn.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ def train_dqn(
1919
"""
2020
Train a DQN agent to play Quoridor.
2121
22+
Julian notes:
23+
- This is for now working for a trivial 3x3 board with no walls
24+
- It teaches the agent to use black (player 2) only, against a random agent
25+
Note that in a 3x3 board with no walls, black always wins (if it wants)
26+
- It's currently not assigning negative rewards for losing
27+
2228
Args:
2329
episodes: Number of episodes to train for
2430
batch_size: Size of batches to sample from replay buffer
@@ -32,11 +38,8 @@ def train_dqn(
3238
"""
3339
game = env(board_size=board_size, max_walls=max_walls, step_rewards=step_rewards)
3440

35-
# Calculate action space size
36-
action_size = board_size**2 + ((board_size - 1) ** 2) * 2
37-
3841
# Create the DQN agent
39-
dqn_agent = FlatDQNAgent(board_size, action_size)
42+
dqn_agent = FlatDQNAgent(board_size, epsilon_decay=0.9999)
4043

4144
# Create a random opponent
4245
random_agent = RandomAgent()
@@ -50,7 +53,6 @@ def train_dqn(
5053
for episode in range(episodes):
5154
game.reset()
5255

53-
# Reset episode-specific variables
5456
episode_reward = 0
5557
episode_losses = []
5658

@@ -59,11 +61,12 @@ def train_dqn(
5961
observation, reward, termination, truncation, _ = game.last()
6062

6163
# If the game is over, break the loop
64+
# TODO: Assign negative reward to the DQN agent when the opponent wins
6265
if termination or truncation:
6366
break
6467

6568
# If it's the DQN agent's turn
66-
if agent_name == "player_0":
69+
if agent_name == "player_1":
6770
# Get current state
6871
state = dqn_agent.preprocess_observation(observation)
6972

@@ -73,8 +76,17 @@ def train_dqn(
7376
# Execute action
7477
game.step(action)
7578

76-
# Get new state, reward, etc.
77-
next_observation, reward, termination, truncation, _ = game.last()
79+
# Get the observation and rewards for THIS agent (not the opponent)
80+
# NOTE: If we used game.last(), it would return the observation and rewards for the currently active agent
81+
# which, since we already did game.step(), is now the opponent
82+
next_observation = game.observe(agent_name)
83+
84+
# Make the reward much larger than 1, to make it stand out
85+
reward = game.rewards[agent_name] * 1000
86+
87+
# See if the game is over
88+
# TODO: Understand what truncation is and whether either of these values is player-dependent
89+
_, _, termination, truncation, _ = game.last()
7890

7991
# Add to episode reward
8092
episode_reward += reward
@@ -115,10 +127,16 @@ def train_dqn(
115127
# Update target network periodically
116128
if episode % update_target_every == 0:
117129
dqn_agent.update_target_network()
118-
avg_reward = sum(total_rewards[-100:]) / min(100, len(total_rewards)) if total_rewards else 0.0
119-
avg_loss = sum(losses[-100:]) / min(100, len(losses)) if losses else 0.0
130+
avg_reward = (
131+
sum(total_rewards[-1 * update_target_every :]) / min(update_target_every, len(total_rewards))
132+
if total_rewards
133+
else 0.0
134+
)
135+
avg_loss = (
136+
sum(losses[-1 * update_target_every :]) / min(update_target_every, len(losses)) if losses else 0.0
137+
)
120138
print(
121-
f"Episode {episode}/{episodes}, Avg Reward: {avg_reward:.2f}, "
139+
f"Episode {episode + 1}/{episodes}, Avg Reward: {avg_reward:.2f}, "
122140
f"Avg Loss: {avg_loss:.4f}, Epsilon: {dqn_agent.epsilon:.4f}"
123141
)
124142

0 commit comments

Comments
 (0)