Skip to content

Commit d465c07

Browse files
authored
Merge pull request #341 from adamantivm/jac/sample_onnx_resnet_model
Updates ONNX save code and experiment to support ResNet
2 parents 1c94373 + cc3c02b commit d465c07

File tree

14 files changed

+233
-231
lines changed

14 files changed

+233
-231
lines changed

deep_quoridor/MODEL_SAVE_OPTIONS.md

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## Summary
44

5-
Added configuration options to save AlphaZero models in PyTorch and/or ONNX formats during training.
5+
Added configuration options to save AlphaZero models in ONNX format during training. PyTorch format (.pt files) is always saved.
66

77
## Configuration Options
88

@@ -12,13 +12,12 @@ Add these parameters to the `training` section of your YAML configuration:
1212
training:
1313
# ... other training parameters ...
1414
model_save_timing: false # Set to true to print timing information for model saving
15-
save_pytorch: true # Save models in PyTorch format (.pt files) - DEFAULT
1615
save_onnx: false # Save models in ONNX format (.onnx files) - DEFAULT
1716
```
1817
1918
## Default Behavior
2019
21-
- **PyTorch format**: Enabled by default (`save_pytorch: true`)
20+
- **PyTorch format**: Always enabled (cannot be disabled)
2221
- **ONNX format**: Disabled by default (`save_onnx: false`)
2322
- **Timing output**: Disabled by default (`model_save_timing: false`)
2423

@@ -29,31 +28,22 @@ This ensures backward compatibility with existing configurations.
2928
### Example 1: Save only PyTorch format (default)
3029
```yaml
3130
training:
32-
save_pytorch: true
3331
save_onnx: false
3432
```
3533

3634
### Example 2: Save both PyTorch and ONNX formats
3735
```yaml
3836
training:
39-
save_pytorch: true
4037
save_onnx: true
4138
model_save_timing: true # See timing for both formats
4239
```
4340

44-
### Example 3: Save only ONNX format
45-
```yaml
46-
training:
47-
save_pytorch: false
48-
save_onnx: true
49-
```
50-
5141
## Test Configurations
5242

5343
Two test configurations are provided:
5444

5545
1. **`experiments/test_model_save_timing.yaml`** - Basic test with PyTorch only
56-
2. **`experiments/test_onnx_export.yaml`** - Test with both PyTorch and ONNX export enabled
46+
2. **`experiments/test_onnx_export.yaml`** - Test with ONNX export enabled
5747

5848
## ONNX Export Details
5949

@@ -67,11 +57,11 @@ The ONNX export includes:
6757
## Files Modified
6858

6959
1. **`src/v2/config.py`**
70-
- Added `save_pytorch`, `save_onnx`, and `model_save_timing` to `TrainingConfig`
60+
- Added `save_onnx` and `model_save_timing` to `TrainingConfig`
7161

7262
2. **`src/v2/trainer.py`**
73-
- Updated initial model save (model_0) to support both formats
74-
- Updated training loop model saves to support both formats
63+
- Updated initial model save (model_0) to support ONNX format
64+
- Updated training loop model saves to support ONNX format
7565
- Enhanced timing output to show which formats were saved
7666

7767
3. **`src/agents/alphazero/alphazero.py`**
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
extend: base.yaml
2+
run_id: sample-onnx-export-$DATETIME
3+
4+
alphazero:
5+
mcts_n: 50
6+
wandb: null # Disable wandb for testing
7+
self_play:
8+
num_workers: 1
9+
parallel_games: 2
10+
alphazero:
11+
mcts_noise_epsilon: 0.25
12+
training:
13+
finish_after: 10 models
14+
model_save_timing: true
15+
save_onnx: true
16+
benchmarks: []

deep_quoridor/experiments/test_model_save_timing.yaml

Lines changed: 0 additions & 26 deletions
This file was deleted.

deep_quoridor/experiments/test_onnx_export.yaml

Lines changed: 0 additions & 26 deletions
This file was deleted.

deep_quoridor/src/agents/alphazero/alphazero.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,17 +419,24 @@ def save_model(self, path):
419419
def save_model_onnx(self, path):
420420
"""Export the model to ONNX format."""
421421
import torch.onnx
422-
422+
423423
# Create directory for saving models if it doesn't exist
424424
os.makedirs(Path(path).absolute().parents[0], exist_ok=True)
425-
425+
426426
# Set the network to evaluation mode
427427
self.evaluator.network.eval()
428-
428+
429429
# Create a dummy input tensor with the correct shape
430-
# The input size is determined by the network's input_size attribute
431-
dummy_input = torch.randn(1, self.evaluator.network.input_size, device=self.device)
432-
430+
# The shape depends on the network type
431+
network = self.evaluator.network
432+
if hasattr(network, "__class__") and network.__class__.__name__ == "ResnetNetwork":
433+
# ResNet expects input of shape (batch_size, 5, input_size, input_size)
434+
# NOTE: input_size is board_size * 2 + 3, which is the dimension of the combined grid input, not the original board size
435+
dummy_input = torch.randn(1, 5, network.input_size, network.input_size, device=self.device)
436+
else:
437+
# MLP expects input of shape (batch_size, input_size)
438+
dummy_input = torch.randn(1, network.input_size, device=self.device)
439+
433440
# Export the model with opset 17 (widely supported, avoids version conversion issues)
434441
torch.onnx.export(
435442
self.evaluator.network,
@@ -445,6 +452,7 @@ def save_model_onnx(self, path):
445452
"policy_logits": {0: "batch_size"},
446453
"value": {0: "batch_size"},
447454
},
455+
external_data=False, # Don't use external data format for simplicity; model should be small enough to fit in a single file
448456
)
449457
print(f"AlphaZero model exported to ONNX at {path}")
450458

deep_quoridor/src/v2/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class TrainingConfig(StrictBaseModel):
7676
weight_decay: float
7777
replay_buffer_size: int
7878
model_save_timing: bool = False
79-
save_pytorch: bool = True
8079
save_onnx: bool = False
8180
finish_after: Optional[str] = None
8281

deep_quoridor/src/v2/trainer.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,10 @@ def train(config: Config):
6464
wandb_run = MockWandb()
6565

6666
# Save initial model (model_0)
67-
if config.training.save_pytorch:
68-
filename = config.paths.checkpoints / "model_0.pt"
69-
alphazero_agent.save_model(filename)
70-
LatestModel.write(config, str(filename), 0)
71-
67+
filename = config.paths.checkpoints / "model_0.pt"
68+
alphazero_agent.save_model(filename)
69+
LatestModel.write(config, str(filename), 0)
70+
7271
if config.training.save_onnx:
7372
onnx_filename = config.paths.checkpoints / "model_0.onnx"
7473
alphazero_agent.save_model_onnx(onnx_filename)
@@ -166,29 +165,27 @@ def train(config: Config):
166165
)
167166

168167
Timer.start("save-model")
169-
170-
# Save in PyTorch format if enabled
171-
if config.training.save_pytorch:
172-
new_model_filename = config.paths.checkpoints / f"model_{model_version}.pt"
173-
alphazero_agent.save_model(new_model_filename)
174-
LatestModel.write(config, str(new_model_filename), model_version)
175-
168+
169+
# Save in PyTorch format
170+
new_model_filename = config.paths.checkpoints / f"model_{model_version}.pt"
171+
alphazero_agent.save_model(new_model_filename)
172+
LatestModel.write(config, str(new_model_filename), model_version)
173+
176174
# Save in ONNX format if enabled
177175
if config.training.save_onnx:
178176
onnx_model_filename = config.paths.checkpoints / f"model_{model_version}.onnx"
179177
alphazero_agent.save_model_onnx(onnx_model_filename)
180-
178+
181179
time_save_model = Timer.finish("save-model")
182-
180+
183181
if config.training.model_save_timing:
184182
formats = []
185-
if config.training.save_pytorch:
186-
formats.append("PyTorch")
183+
formats.append("PyTorch")
187184
if config.training.save_onnx:
188185
formats.append("ONNX")
189186
format_str = " and ".join(formats) if formats else "no format"
190187
print(f"Saving model ({format_str}) took {time_save_model:.4f}s")
191-
188+
192189
model_version += 1
193190

194191
ShutdownSignal.signal(config)

deep_quoridor/test/config_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ def test_override_boolean_true(config_file):
5151
assert config.training.model_save_timing is True
5252

5353

54-
def test_override_boolean_false(config_file):
55-
config = load_user_config(config_file, overrides=["training.save_pytorch=false"])
56-
assert config.training.save_pytorch is False
57-
58-
5954
def test_override_int(config_file):
6055
config = load_user_config(config_file, overrides=["alphazero.mcts_n=500"])
6156
assert config.alphazero.mcts_n == 500
1.2 MB
Binary file not shown.
403 KB
Binary file not shown.

0 commit comments

Comments
 (0)