Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deep_quoridor/src/train_v2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import argparse
import multiprocessing as mp
import os
import subprocess
import time
from pathlib import Path

from v2 import benchmarks, load_config_and_setup_run, self_play, train
from v2.common import ShutdownSignal

# Prevents getting messages in the console every few lines telling you to install weave
os.environ["WANDB_DISABLE_WEAVE"] = "true"

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train Quoridor agent")
parser.add_argument("config_file", type=str, help="Path to YAML configuration file")
Expand Down
10 changes: 10 additions & 0 deletions deep_quoridor/src/v2/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ def alphazero_params_dict_from_config(
"mcts_ucb_c": config.alphazero.mcts_c_puct,
}

if config.training.initial_model:
im = config.training.initial_model
if im.file:
params_dict["model_filename"] = im.file
if im.wandb_alias:
params_dict["wandb_alias"] = im.wandb_alias
params_dict["wandb_project"] = im.wandb_project or (
config.wandb.project if config.wandb else "deep_quoridor"
)

# Add network config
if config.alphazero.network.type == "mlp":
params_dict.update(
Expand Down
14 changes: 14 additions & 0 deletions deep_quoridor/src/v2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ class SelfPlayConfig(StrictBaseModel):
rust_selfplay_binary: Optional[str] = None


class InitialModel(StrictBaseModel):
file: Optional[str] = None
wandb_project: Optional[str] = None
wandb_alias: Optional[str] = None

@field_validator("wandb_alias")
@classmethod
def file_and_wandb_mutually_exclusive(cls, v, info):
if v is not None and info.data.get("file") is not None:
raise ValueError("Cannot specify both 'file' and 'wandb_alias' in initial_model")
return v


class TrainingConfig(StrictBaseModel):
games_per_training_step: float
learning_rate: float
Expand All @@ -80,6 +93,7 @@ class TrainingConfig(StrictBaseModel):
model_save_timing: bool = False
save_onnx: bool = False
finish_after: Optional[str] = None
initial_model: Optional[InitialModel] = None


class TournamentBenchmarkConfig(StrictBaseModel):
Expand Down
1 change: 0 additions & 1 deletion deep_quoridor/src/v2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def model_uploader(config: Config, every: str, model_id: str, wandb_run, shutdow

def train(config: Config):
batch_size = config.training.batch_size

alphazero_agent = create_alphazero(config, config.self_play.alphazero, overrides={"training_mode": True})

upload_model_thread = None
Expand Down