Skip to content

Commit 8de7737

Browse files
Merge pull request #309 from jonbinney/mcts0
Raw NN play and dumb_score_raw metric
2 parents d41e908 + a5488cc commit 8de7737

File tree

5 files changed

+50
-11
lines changed

5 files changed

+50
-11
lines changed

deep_quoridor/src/agents/alphazero/mcts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ def search_batch(self, initial_games: list[Quoridor]):
155155
max_iterations = max(num_iterations)
156156

157157
roots = [Node(g, ucb_c=self.ucb_c) for g in initial_games]
158+
159+
# When n is 0, it plays just with the NN and doesn't actually perform MCTS.
160+
# For this, we just set the visit counts to a value proportional to the prior
161+
if self.n == 0:
162+
value_batch, priors_batch = self.evaluator.evaluate_batch([node.game for node in roots])
163+
for root, value, priors in zip(roots, value_batch, priors_batch):
164+
root.expand(priors)
165+
root.backpropagate(-value)
166+
for ch in root.children:
167+
ch.visit_count = int(ch.prior * 1000)
168+
169+
return [root.children for root in roots], [-(root.value_sum / root.visit_count) for root in roots]
170+
158171
for iteration in range(max_iterations):
159172
need_evaluation = [] # (root, node)
160173
for game_idx, root in enumerate(roots):

deep_quoridor/src/metrics.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from agents import Agent
2+
from agents.alphazero.alphazero import AlphaZeroAgent
23
from agents.core.agent import AgentRegistry
34
from arena import Arena, PlayMode
45
from arena_utils import GameResult
56
from quoridor_env import env
67
from renderers.match_results import MatchResultsRenderer
78
from utils.misc import compute_elo, get_opponent_player_id
9+
from utils.subargs import override_subargs
810

911

1012
class Metrics:
@@ -78,7 +80,7 @@ def _compute_relative_elo(self, elo_table: dict[str, float], agent_name: str) ->
7880

7981
def compute(
8082
self, agent_encoded_name: str
81-
) -> tuple[int, dict[str, float], int, float, dict[str, float], dict[str, float], int, int]:
83+
) -> tuple[int, dict[str, float], int, float, dict[str, float], dict[str, float], int, int, int]:
8284
"""
8385
Evaluates the performance of a given agent by running it against a set of predefined opponents and computing its Elo rating and win percentage.
8486
@@ -95,6 +97,7 @@ def compute(
9597
- p2_win_percentages (dict[str, float]): Win percentage as player two against each opponent.
9698
- absolute_elo (int): ELO rating obtained during the tournament
9799
- dumb_score (int): A score between 0 (perfect) and 100 (always wrong) on how the agent performs in certain basic situations
100+
- dumb_score_raw (int): For AlphaZeroAgent, same as dumb_score but using the raw network rather than MCTS. For other agents, returns dumb_score.
98101
99102
Notes:
100103
- The method disables training mode for trainable agents during evaluation and restores it afterward.
@@ -142,6 +145,11 @@ def compute(
142145

143146
dumb_score = self.dumb_score(agent)
144147

148+
if isinstance(agent, AlphaZeroAgent):
149+
raw_play_encoded_name = override_subargs(play_encoded_name, {"mcts_n": 0})
150+
agent_raw = AgentRegistry.create_from_encoded_name(raw_play_encoded_name, arena.game)
151+
dumb_score_raw = self.dumb_score(agent_raw, verbose=True)
152+
145153
return (
146154
VERSION,
147155
elo_table,
@@ -151,6 +159,7 @@ def compute(
151159
p2_win_percentages,
152160
int(absolute_elo),
153161
dumb_score,
162+
dumb_score_raw,
154163
)
155164

156165
def dumb_score(self, agent: Agent, verbose: bool = False):

deep_quoridor/src/play.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def player_with_params(arg):
6868
"-mx",
6969
"--max_steps",
7070
type=int,
71-
default=10000,
72-
help="Maximum number of steps per game. Default is 10000",
71+
default=200,
72+
help="Maximum number of steps per game. Default is 200",
7373
)
7474
parser.add_argument(
7575
"--profile",

deep_quoridor/src/plugins/wandb_train.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,17 @@ def compute_tournament_metrics(self, model_filename: str) -> int:
159159
agent_encoded_name = override_subargs(self.agent_encoded_name, override_args)
160160

161161
Timer.start("benchmark")
162-
_, _, relative_elo, win_perc, p1_win_percentages, p2_win_percentages, absolute_elo, dumb_score = (
163-
self.metrics.compute(agent_encoded_name)
164-
)
162+
(
163+
_,
164+
_,
165+
relative_elo,
166+
win_perc,
167+
p1_win_percentages,
168+
p2_win_percentages,
169+
absolute_elo,
170+
dumb_score,
171+
dumb_score_raw,
172+
) = self.metrics.compute(agent_encoded_name)
165173
Timer.finish("benchmark", self.episode_count)
166174

167175
print(f"Tournament Metrics - Relative elo: {relative_elo}, win percentage: {win_perc}")
@@ -176,6 +184,7 @@ def compute_tournament_metrics(self, model_filename: str) -> int:
176184
"win_perc": win_perc,
177185
"absolute_elo": absolute_elo,
178186
"dumb_score": dumb_score,
187+
"dumb_score_raw": dumb_score_raw,
179188
"Episode": self.episode_count, # x axis
180189
}
181190

deep_quoridor/src/run_metrics.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,22 @@ def player_with_params(arg):
4949
args = parser.parse_args()
5050
m = Metrics(args.board_size, args.max_walls, args.benchmarks, args.benchmarks_t, args.max_steps, args.num_workers)
5151
table = PrettyTable()
52-
table.field_names = ["Agent", "Elo", "Relative Elo", "Win %", "Dumb Score"]
52+
table.field_names = ["Agent", "Elo", "Relative Elo", "Win %", "Dumb Score", "Raw Dumb Score"]
5353

5454
for player in args.players:
5555
player_nick = AgentRegistry.nick_from_encoded_name(player)
5656
print(f"=== Computing metrics for {player_nick} ===")
57-
_, _, relative_elo, win_perc, p1_win_percentages, p2_win_percentages, absolute_elo, dumb_score = m.compute(
58-
player
59-
)
60-
table.add_row([player_nick, absolute_elo, relative_elo, f"{win_perc:.2f}", dumb_score])
57+
(
58+
_,
59+
_,
60+
relative_elo,
61+
win_perc,
62+
p1_win_percentages,
63+
p2_win_percentages,
64+
absolute_elo,
65+
dumb_score,
66+
dumb_score_raw,
67+
) = m.compute(player)
68+
table.add_row([player_nick, absolute_elo, relative_elo, f"{win_perc:.2f}", dumb_score, dumb_score_raw])
6169

6270
print(table)

0 commit comments

Comments
 (0)