Skip to content

Commit ea0f2b2

Browse files
Merge pull request #314 from jonbinney/multi-gpu
Support for running in Multi-GPU setups
2 parents a6e7335 + aa41f5e commit ea0f2b2

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

deep_quoridor/src/agents/alphazero/self_play_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
import wandb
1212
from plugins.wandb_train import WandbParams
13-
from utils import set_deterministic
13+
from utils import my_device, set_deterministic
1414
from utils.timer import Timer
1515

1616
from agents.alphazero.alphazero import AlphaZeroAgent, AlphaZeroParams
@@ -165,11 +165,11 @@ def run_self_play_games(
165165
set_deterministic(random_seed)
166166

167167
print(
168-
f"Worker {worker_id} starting, running {num_games} games ({num_parallel_games} in parallel) with random seed {random_seed}"
168+
f"Worker {worker_id} starting ({my_device()}), running {num_games} games ({num_parallel_games} in parallel) with random seed {random_seed}"
169169
)
170170

171171
wandb_run = None
172-
if wandb_params is not None:
172+
if wandb_params is not None and wandb_params.log_from_workers:
173173
run_id = f"{wandb_params.run_id()}-worker-{worker_id}"
174174
print(f"Wandb group: {wandb_params.run_id()} run: {run_id}")
175175
wandb_run = wandb.init(

deep_quoridor/src/plugins/wandb_train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ class WandbParams(SubargsBase):
3838
# How often to log training metrics
3939
log_every: int = 10
4040

41+
# Whether workers will also log to wandb (in separate runs)
42+
log_from_workers: bool = True
43+
4144
def run_id(self):
4245
return f"{self.prefix}-{self.suffix}"
4346

@@ -124,7 +127,7 @@ def upload_model(self, model_file: str, extra_files: list[str] = []) -> str:
124127

125128
artifact.save()
126129
logged_artifact = wandb.log_artifact(artifact)
127-
logged_artifact.wait(60)
130+
logged_artifact.wait(300)
128131
logged_artifact.aliases.extend([f"ep_{self.episode_count}-{self.run.id}"])
129132
logged_artifact.save()
130133

deep_quoridor/src/train_alphazero.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ def train_alphazero(
7777
# because it calls the plugin's internal _intialize method which sets up metrics.
7878
wandb_train_plugin.start_game(game=args, agent1=training_agent, agent2=training_agent)
7979
training_agent.set_wandb_run(wandb_train_plugin.run)
80-
8180
# Compute the tournament metrics with the initial model, possibly random initialized, to
8281
# be able to see how it evolves from there
82+
wandb_train_plugin.episode_count = initial_epoch * args.games_per_epoch
8383
wandb_train_plugin.compute_tournament_metrics(str(current_filename))
8484

8585
last_epoch = initial_epoch + args.epochs
@@ -122,16 +122,16 @@ def train_alphazero(
122122
current_filename = training_agent.save_model_with_suffix(f"_epoch_{epoch}")
123123
if wandb_train_plugin is not None:
124124
wandb_train_plugin.episode_count = game_num
125-
# Compute the metrics periodically and in the last epoch
126-
if (epoch + 1) % args.benchmarks_every == 0 or epoch == last_epoch - 1:
127-
wandb_train_plugin.compute_tournament_metrics(str(current_filename))
128-
129125
# Upload the model and training state
130126
with tempfile.TemporaryDirectory() as tmpdir:
131127
training_state_filename = os.path.join(tmpdir, "training_state.gz")
132128
save_training_state(training_state_filename, training_agent, wandb_train_plugin, epoch + 1, game_num)
133129
wandb_train_plugin.upload_model(str(current_filename), [training_state_filename])
134130

131+
# Compute the metrics periodically and in the last epoch
132+
if (epoch + 1) % args.benchmarks_every == 0 or epoch == last_epoch - 1:
133+
wandb_train_plugin.compute_tournament_metrics(str(current_filename))
134+
135135
Timer.log_totals()
136136

137137
# Close the arena to finish wandb run

deep_quoridor/src/utils/misc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
2+
import os
23
import random
4+
from functools import cache
35
from glob import glob
46
from pathlib import Path
57
from typing import Optional
@@ -40,8 +42,14 @@ def get_initial_random_seed():
4042
return initial_random_seed
4143

4244

45+
@cache
4346
def my_device():
4447
if torch.cuda.is_available():
48+
dc = torch.cuda.device_count()
49+
if dc > 1:
50+
gpu_n = os.getpid() % dc
51+
return torch.device(f"cuda:{gpu_n}")
52+
4553
return torch.device("cuda")
4654

4755
if torch.backends.mps.is_available():

0 commit comments

Comments
 (0)