Skip to content

Commit c828c56

Browse files
Merge pull request #340 from jonbinney/overrides
Override config from the command line
2 parents bcfcc55 + 6b1838b commit c828c56

File tree

4 files changed

+189
-8
lines changed

4 files changed

+189
-8
lines changed

deep_quoridor/src/train_v2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
parser.add_argument("-r", "--runs-dir", type=str, default=None, help="Directory for runs")
1313
# TODO: implement this
1414
# parser.add_argument("-c", "--continue", dest="continue_run", action="store_true", help="Continue an existing run")
15-
# parser.add_argument(
16-
# "-o", "--overrides", nargs="*", help="Configuration overrides (e.g., run_id=my_run alphazero.mcts_n=250)"
17-
# )
15+
parser.add_argument(
16+
"-o", "--overrides", nargs="*", help="Configuration overrides (e.g., run_id=my_run alphazero.mcts_n=250)"
17+
)
1818

1919
args = parser.parse_args()
2020

2121
runs_dir = args.runs_dir if args.runs_dir is not None else str(Path(__file__).parent.parent)
2222

23-
config = load_config_and_setup_run(args.config_file, runs_dir)
23+
config = load_config_and_setup_run(args.config_file, runs_dir, overrides=args.overrides)
2424
mp.set_start_method("spawn", force=True)
2525

2626
# Make sure we don't have the shutdown signal from a previous run

deep_quoridor/src/v2/TODO.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# V1 Parity
22

33
- Replay buffer length: right now we're not rolling out old games to respect the length
4-
- Overrides from the command line
54
- Continuation
65

76
# Other improvements and new features

deep_quoridor/src/v2/config.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,85 @@ def _load_config_data(file: str) -> dict:
213213
return data
214214

215215

216-
def load_user_config(file: str) -> UserConfig:
216+
def _parse_override_value(value: str):
217+
"""Parse a string value into an appropriate Python type."""
218+
if value.lower() == "none":
219+
return None
220+
if value.lower() == "true":
221+
return True
222+
if value.lower() == "false":
223+
return False
224+
if value.startswith("[") and value.endswith("]"):
225+
inner = value[1:-1].strip()
226+
if not inner:
227+
return []
228+
return [_parse_override_value(item.strip()) for item in inner.split(",")]
229+
try:
230+
return int(value)
231+
except ValueError:
232+
pass
233+
try:
234+
return float(value)
235+
except ValueError:
236+
pass
237+
return value
238+
239+
240+
def _as_index(part: str) -> int:
241+
try:
242+
return int(part)
243+
except ValueError:
244+
raise ValueError(f"Expected a numeric index for list, got '{part}'")
245+
246+
247+
def _ensure_and_navigate(target, part: str):
248+
"""Navigate into an intermediate path part, creating a dict if the key is missing."""
249+
if isinstance(target, list):
250+
return target[_as_index(part)]
251+
if part not in target:
252+
target[part] = {}
253+
return target[part]
254+
255+
256+
def _set_value(target, part: str, value):
257+
"""Set a value on a dict key or list index."""
258+
if isinstance(target, list):
259+
target[_as_index(part)] = value
260+
else:
261+
target[part] = value
262+
263+
264+
def _apply_overrides(data: dict, overrides: list[str]) -> dict:
265+
"""Apply dotted-key overrides (e.g. 'alphazero.mcts_n=250', 'wandb=None') to a config dict.
266+
267+
Supports numeric indices for lists: 'benchmarks.0.every=5m'
268+
"""
269+
for override in overrides:
270+
if "=" not in override:
271+
raise ValueError(f"Invalid override format '{override}', expected 'key=value'")
272+
key, value = override.split("=", 1)
273+
parts = key.split(".")
274+
parsed_value = _parse_override_value(value)
275+
276+
target = data
277+
for part in parts[:-1]:
278+
target = _ensure_and_navigate(target, part)
279+
_set_value(target, parts[-1], parsed_value)
280+
281+
return data
282+
283+
284+
def load_user_config(file: str, overrides: list[str] | None = None) -> UserConfig:
    """Load the config data from ``file``, apply any overrides, and validate it.

    ``overrides`` is a list of 'key=value' strings; ``None`` or an empty list
    leaves the loaded data untouched.
    """
    raw = _load_config_data(file)
    # _apply_overrides edits the dict in place and hands the same dict back.
    patched = _apply_overrides(raw, overrides) if overrides else raw
    return UserConfig.model_validate(patched)
219289

220290

221-
def load_config_and_setup_run(file: str, base_dir: str, create_dirs: bool = True) -> Config:
222-
user_config = load_user_config(file)
291+
def load_config_and_setup_run(
292+
file: str, base_dir: str, overrides: list[str] | None = None, create_dirs: bool = True
293+
) -> Config:
294+
user_config = load_user_config(file, overrides=overrides)
223295
config = Config.from_user(user_config, base_dir, create_dirs=create_dirs)
224296

225297
config_filename = config.paths.config_file

deep_quoridor/test/config_test.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import pytest
2+
import yaml
3+
from v2.config import load_user_config
4+
5+
# Baseline configuration that the `config_file` fixture dumps to a temporary
# YAML file; the tests load it with and without overrides.
EXAMPLE_CONFIG = {
    "run_id": "test-run",
    "quoridor": {"board_size": 5, "max_walls": 3, "max_steps": 50},
    "alphazero": {"network": {"type": "mlp"}, "mcts_n": 300, "mcts_c_puct": 1.2},
    "wandb": {"project": "example", "upload_model": {"every": "20 models", "when_max": ["raw_win_perc", "elo_score"]}},
    "self_play": {"num_workers": 2, "parallel_games": 8, "alphazero": {"mcts_noise_epsilon": 0.25}},
    "training": {
        "games_per_training_step": 25.0,
        "learning_rate": 0.001,
        "batch_size": 256,
        "weight_decay": 0.0001,
        "replay_buffer_size": 1000000,
    },
    # A list section, used by the tests that override values through numeric indices.
    "benchmarks": [
        {
            "every": "10 models",
            "jobs": [
                {"type": "tournament", "prefix": "raw", "times": 10, "opponents": ["random", "greedy"]},
                {"type": "dumb_score", "prefix": "raw"},
            ],
        },
    ],
}
28+
29+
30+
@pytest.fixture
def config_file(tmp_path):
    """Write EXAMPLE_CONFIG as YAML under ``tmp_path`` and return the file path as a string."""
    target = tmp_path / "config.yaml"
    serialized = yaml.safe_dump(EXAMPLE_CONFIG, sort_keys=False)
    target.write_text(serialized)
    return str(target)
35+
36+
37+
def test_no_overrides(config_file):
    """Without overrides, the loaded config reflects the values written to the file."""
    cfg = load_user_config(config_file)
    wandb_cfg = cfg.wandb
    assert wandb_cfg is not None
    assert wandb_cfg.project == "example"
    assert cfg.training.learning_rate == 0.001
42+
43+
44+
def test_override_none(config_file):
    """'wandb=None' replaces the whole wandb section with None."""
    cli_overrides = ["wandb=None"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.wandb is None
47+
48+
49+
def test_override_boolean_true(config_file):
    """The string 'True' is parsed as the boolean True."""
    cli_overrides = ["training.model_save_timing=True"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.training.model_save_timing is True
52+
53+
54+
def test_override_boolean_false(config_file):
    """The lowercase string 'false' is parsed as the boolean False."""
    cli_overrides = ["training.save_pytorch=false"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.training.save_pytorch is False
57+
58+
59+
def test_override_int(config_file):
    """An integer-looking value overrides a nested integer setting."""
    cli_overrides = ["alphazero.mcts_n=500"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.alphazero.mcts_n == 500
62+
63+
64+
def test_override_float(config_file):
    """A decimal value overrides a nested float setting."""
    cli_overrides = ["training.learning_rate=0.01"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.training.learning_rate == 0.01
67+
68+
69+
def test_override_string(config_file):
    """A value that is not a keyword, list, or number is kept as a string."""
    cli_overrides = ["run_id=my-custom-run"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.run_id == "my-custom-run"
72+
73+
74+
def test_override_list(config_file):
    """A bracketed, comma-separated value replaces a list setting."""
    cli_overrides = ["wandb.upload_model.when_max=[dumb_score,tournament]"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.wandb.upload_model.when_max == ["dumb_score", "tournament"]
77+
78+
79+
def test_override_empty_list(config_file):
    """'[]' replaces a list setting with an empty list."""
    cli_overrides = ["wandb.upload_model.when_max=[]"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.wandb.upload_model.when_max == []
82+
83+
84+
def test_override_list_index(config_file):
    """A numeric path segment addresses an element of a list."""
    cli_overrides = ["benchmarks.0.every=5 models"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    first_benchmark = cfg.benchmarks[0]
    assert first_benchmark.every == "5 models"
87+
88+
89+
def test_override_nested_list_index(config_file):
    """Numeric path segments can be chained through nested lists."""
    cli_overrides = ["benchmarks.0.jobs.0.times=20"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    first_job = cfg.benchmarks[0].jobs[0]
    assert first_job.times == 20
92+
93+
94+
def test_multiple_overrides(config_file):
    """Several overrides passed together are all applied."""
    cli_overrides = ["alphazero.mcts_n=100", "training.learning_rate=0.05", "wandb=None"]
    cfg = load_user_config(config_file, overrides=cli_overrides)
    assert cfg.alphazero.mcts_n == 100
    assert cfg.training.learning_rate == 0.05
    assert cfg.wandb is None
101+
102+
103+
def test_invalid_override_format(config_file):
    """An override without '=' raises ValueError mentioning the invalid format."""
    malformed = ["no_equals_sign"]
    with pytest.raises(ValueError, match="Invalid override format"):
        load_user_config(config_file, overrides=malformed)
106+
107+
108+
def test_invalid_key_rejected_by_pydantic(config_file):
    """An override for a key the config does not define fails validation.

    The expected exception is narrowed from ``Exception`` to ``ValueError``:
    pydantic v2's ``ValidationError`` derives from ``ValueError``, so an
    unrelated failure (e.g. a ``TypeError`` or ``NameError``) no longer makes
    this test pass by accident.
    """
    with pytest.raises(ValueError):
        load_user_config(config_file, overrides=["nonexistent_key=value"])

0 commit comments

Comments
 (0)