-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
218 lines (173 loc) · 8.75 KB
/
agent.py
File metadata and controls
218 lines (173 loc) · 8.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
Snake Game AI Agent using Deep Q-Learning (DQN)
This module contains the AI agent that learns to play Snake using reinforcement learning.
The agent uses a neural network to predict the best actions based on the current game state.
"""
import torch
import random
import numpy as np
from collections import deque
from typing import List, Tuple, Optional
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot
# --- Training hyperparameters -------------------------------------------
# Replay buffer capacity: how many past experiences are kept at most.
MAX_MEMORY = 10 ** 5
# How many stored experiences are drawn for one replay-training pass.
BATCH_SIZE = 1_000
# Optimizer step size (smaller values learn more slowly but more stably).
LEARNING_RATE = 1e-3
# Weight given to future rewards relative to immediate ones.
DISCOUNT_FACTOR = 0.9
# Edge length of one board cell in pixels (should match game.py).
BLOCK_SIZE = 20
class Agent:
    """
    Deep Q-Learning agent for the Snake game.

    The agent:
    1. Encodes the game into an 11-element state vector (see ``get_state``).
    2. Picks one of three relative moves (straight, right, left) with an
       epsilon-greedy policy over the network's Q-values.
    3. Learns from rewards, both right after every move and from batches
       replayed out of memory.
    """

    # Exploration schedule used by ``choose_action``: epsilon starts at
    # EXPLORATION_GAMES and drops by one per finished game; a move is random
    # when a roll in [0, EXPLORATION_ROLL_MAX] lands strictly below epsilon.
    EXPLORATION_GAMES = 80
    EXPLORATION_ROLL_MAX = 200

    def __init__(self):
        """Set up counters, replay memory, the Q-network and its trainer."""
        self.games_played = 0
        self.exploration_rate = 0  # epsilon; recomputed on every choose_action call
        self.discount_factor = DISCOUNT_FACTOR
        # Bounded deque: once MAX_MEMORY is reached the oldest entries fall out.
        self.replay_memory = deque(maxlen=MAX_MEMORY)
        # Network shape: 11 state features -> 256 hidden -> 3 actions
        # (straight, right, left).
        self.brain = Linear_QNet(input_size=11, hidden_size=256, output_size=3)
        self.trainer = QTrainer(
            self.brain,
            learning_rate=LEARNING_RATE,
            discount_factor=self.discount_factor,
        )

    def get_state(self, game: "SnakeGameAI") -> np.ndarray:
        """
        Encode the current game as an integer vector of 11 flags.

        Layout of the returned array:
        - [0:3]  danger straight ahead / after a right turn / after a left turn
        - [3:7]  current direction, one-hot (left, right, up, down)
        - [7:11] food position relative to the head (left, right, up, down)
        """
        head = game.snake[0]

        # Cells one block away from the head in each absolute direction.
        point_left = Point(head.x - BLOCK_SIZE, head.y)
        point_right = Point(head.x + BLOCK_SIZE, head.y)
        point_up = Point(head.x, head.y - BLOCK_SIZE)
        point_down = Point(head.x, head.y + BLOCK_SIZE)

        moving_left = game.direction == Direction.LEFT
        moving_right = game.direction == Direction.RIGHT
        moving_up = game.direction == Direction.UP
        moving_down = game.direction == Direction.DOWN

        # Translate the absolute heading into the three relative cells
        # (ahead, after turning right, after turning left).
        if moving_right:
            relative_cells = (point_right, point_down, point_up)
        elif moving_left:
            relative_cells = (point_left, point_up, point_down)
        elif moving_up:
            relative_cells = (point_up, point_right, point_left)
        elif moving_down:
            relative_cells = (point_down, point_left, point_right)
        else:
            relative_cells = None  # unknown heading: report no danger

        if relative_cells is None:
            dangers = [False, False, False]
        else:
            dangers = [game.is_collision(cell) for cell in relative_cells]

        state_features = [
            *dangers,
            # Current movement direction (one-hot encoded)
            moving_left,
            moving_right,
            moving_up,
            moving_down,
            # Food location relative to the head. NOTE(review): this reads
            # game.head while the danger checks use game.snake[0]; presumably
            # both are the same point - confirm against game.py.
            game.food.x < game.head.x,  # food is to the left
            game.food.x > game.head.x,  # food is to the right
            game.food.y < game.head.y,  # food is above (y decreases upward)
            game.food.y > game.head.y,  # food is below
        ]
        return np.array(state_features, dtype=int)

    def remember_experience(self, state: np.ndarray, action: List[int], reward: float,
                            next_state: np.ndarray, game_over: bool) -> None:
        """Append one (state, action, reward, next_state, game_over) tuple to replay memory."""
        self.replay_memory.append((state, action, reward, next_state, game_over))

    def train_on_batch(self) -> None:
        """
        Train the network on experiences drawn from replay memory.

        Uses a random sample of BATCH_SIZE experiences when more than that
        are stored, otherwise everything stored so far. Does nothing when
        the memory is empty.
        """
        if not self.replay_memory:
            # zip(*[]) yields nothing, so the five-way unpacking below would
            # raise ValueError; there is nothing to learn from yet anyway.
            return

        if len(self.replay_memory) > BATCH_SIZE:
            batch_sample = random.sample(self.replay_memory, BATCH_SIZE)
        else:
            batch_sample = self.replay_memory

        # Transpose the list of experiences into one tuple per field.
        states, actions, rewards, next_states, game_overs = zip(*batch_sample)
        self.trainer.train_step(states, actions, rewards, next_states, game_overs)

    def train_on_last_move(self, state: np.ndarray, action: List[int], reward: float,
                           next_state: np.ndarray, game_over: bool) -> None:
        """Run a single training step on the most recent transition."""
        self.trainer.train_step(state, action, reward, next_state, game_over)

    def choose_action(self, state: np.ndarray) -> List[int]:
        """
        Pick the next move with an epsilon-greedy policy.

        Returns a one-hot list ``[straight, right, left]``. As a side effect
        ``self.exploration_rate`` is refreshed from ``self.games_played``.
        """
        # Epsilon shrinks by one per finished game and bottoms out at zero.
        self.exploration_rate = max(0, self.EXPLORATION_GAMES - self.games_played)
        action = [0, 0, 0]  # [straight, right, left]

        if random.randint(0, self.EXPLORATION_ROLL_MAX) < self.exploration_rate:
            # Explore: uniformly random move.
            action[random.randint(0, 2)] = 1
        else:
            # Exploit: take the move with the highest predicted Q-value.
            state_tensor = torch.tensor(state, dtype=torch.float)
            # Pure inference - no gradients are needed for action selection.
            with torch.no_grad():
                action_values = self.brain(state_tensor)
            best_action = torch.argmax(action_values).item()
            action[best_action] = 1
        return action
def train_agent() -> None:
    """
    Run the training loop for the Snake agent.

    Each iteration plays exactly one move: the state is encoded, the agent
    picks an action, the game advances, and the resulting transition is
    trained on immediately and stored in replay memory. Whenever a game
    ends, the game is reset, a replay batch is trained, the model is saved
    if the score is a new record, and the score plot is refreshed.
    A KeyboardInterrupt (Ctrl+C) is caught to print a short summary.
    """
    print("Starting Snake AI training...")
    print("Close the game window or press Ctrl+C to stop training\n")

    # Progress bookkeeping
    per_game_scores = []
    running_means = []
    score_sum = 0
    best_score = 0

    # The learner and the environment it plays in
    agent = Agent()
    game = SnakeGameAI()

    try:
        while True:
            # One move: observe -> act -> observe again
            state_before = agent.get_state(game)
            action = agent.choose_action(state_before)
            reward, game_over, score = game.play_step(action)
            state_after = agent.get_state(game)

            # Learn from the fresh transition, then keep it for replay
            transition = (state_before, action, reward, state_after, game_over)
            agent.train_on_last_move(*transition)
            agent.remember_experience(*transition)

            if not game_over:
                continue

            # End of a game: start a new one and train on replayed experiences
            game.reset()
            agent.games_played += 1
            agent.train_on_batch()

            if score > best_score:
                best_score = score
                agent.brain.save()  # persist the model that set the record
                print(f"New best score: {best_score}!")

            print(f"Game {agent.games_played:4d} | Score: {score:2d} | Record: {best_score:2d}")

            # Update the per-game and running-average curves and redraw them
            score_sum += score
            per_game_scores.append(score)
            running_means.append(score_sum / agent.games_played)
            plot(per_game_scores, running_means)
    except KeyboardInterrupt:
        print(f"\nTraining stopped by user after {agent.games_played} games")
        print(f"Best score achieved: {best_score}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_agent()