
Commit 9239215

Save model config to W&B (#225)
* 👽 Write the model_config to the files subdirectory and tell W&B to save it
* ✅ Fix tests to deal with new config location
* 🚚 Move get_wandb_run method into utils as it might be more widely useful
1 parent 4652ad9 commit 9239215

3 files changed: 30 additions & 16 deletions


icenet_mp/model_service.py

Lines changed: 11 additions & 14 deletions
@@ -8,16 +8,14 @@
 from lightning import Callback, Trainer
 from lightning.fabric.utilities import suggested_max_num_workers
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.loggers import WandbLogger
 from omegaconf import DictConfig, OmegaConf
 from wandb.sdk.lib.runid import generate_id
-from wandb.wandb_run import Run
 
 from icenet_mp.callbacks import UnconditionalCheckpoint
 from icenet_mp.data_loaders import CommonDataModule
 from icenet_mp.models.base_model import BaseModel
 from icenet_mp.types import SupportsMetadata
-from icenet_mp.utils import get_device_name, get_timestamp
+from icenet_mp.utils import get_device_name, get_timestamp, get_wandb_run
 
 if TYPE_CHECKING:
     from lightning.pytorch.loggers import Logger as LightningLogger
@@ -79,7 +77,7 @@ def from_checkpoint(
         # Build a combined model configuration where the command line config takes
         # precedence except for the "model", "predict" and "train" keys which are
         # related to training the model.
-        config_path = checkpoint_path.parent.parent / "model_config.yaml"
+        config_path = checkpoint_path.parent.parent / "files" / "model_config.yaml"
         try:
             # Load the model configuration from the checkpoint directory
             ckpt_config = DictConfig(OmegaConf.load(config_path))
@@ -132,16 +130,12 @@ def model(self) -> BaseModel:
 
     @property
     def run_directory(self) -> Path:
-        """Get run directory from wandb logger or generate one in the same format."""
+        """Get run directory from Wandb or generate one in the same format."""
         if not self.run_directory_:
-            # Get the run directory from the WandbLogger if it exists
-            for lightning_logger in self.trainer.loggers:
-                if not isinstance(lightning_logger, WandbLogger):
-                    continue
-                if not isinstance(experiment := lightning_logger.experiment, Run):
-                    continue
-                self.run_directory_ = Path(experiment._settings.sync_dir)
-                break
+            # Get the run directory from Wandb if it exists
+            wandb_run = get_wandb_run(self.trainer)
+            if wandb_run:
+                self.run_directory_ = Path(wandb_run._settings.sync_dir)
 
         # Otherwise generate a new run directory
         if not self.run_directory_:
@@ -233,7 +227,10 @@ def configure_trainer(
             callback.dirpath = self.run_directory / "checkpoints"
 
         # Save model config to the run directory
-        OmegaConf.save(self.config, self.run_directory / "model_config.yaml")
+        model_config_path = self.run_directory / "files" / "model_config.yaml"
+        OmegaConf.save(self.config, model_config_path)
+        if wandb_run := get_wandb_run(self.trainer):
+            wandb_run.save(model_config_path, base_path=model_config_path.parent)
 
     def evaluate(self) -> None:
         """Evaluate a trained model."""

icenet_mp/utils.py

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,9 @@
 
 import numpy as np
 import torch
+from lightning import Trainer
+from lightning.pytorch.loggers import WandbLogger
+from wandb.wandb_run import Run
 
 
 def datetime_from_npdatetime(dt: np.datetime64) -> datetime:
@@ -31,6 +34,16 @@ def get_timestamp() -> str:
     return datetime.now(tz=UTC).strftime(r"%Y%m%d_%H%M%S")
 
 
+def get_wandb_run(trainer: Trainer) -> Run | None:
+    """Get the Wandb Run instance if it exists."""
+    for lightning_logger in trainer.loggers:
+        if isinstance(lightning_logger, WandbLogger) and isinstance(
+            experiment := lightning_logger.experiment, Run
+        ):
+            return experiment
+    return None
+
+
 def normalise_date(np_datetime: np.datetime64) -> np.datetime64:
     """Normalise a datetime to midnight."""
     dt: datetime = np_datetime.astype("datetime64[ms]").astype(datetime)
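
Because `get_wandb_run` now lives in `icenet_mp.utils`, other components can reuse the same guard. The callback below is a hypothetical example, not part of this commit.

```python
# Hypothetical use of get_wandb_run outside model_service.py: a callback that
# only touches the W&B run when a WandbLogger is actually attached.
from lightning import Callback, LightningModule, Trainer

from icenet_mp.utils import get_wandb_run


class TagWandbRun(Callback):
    """Illustrative callback, not part of this commit."""

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if wandb_run := get_wandb_run(trainer):
            # Run.tags accepts reassignment, so append a tag non-destructively.
            wandb_run.tags = (*(wandb_run.tags or ()), "icenet_mp")
```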

tests/test_model_service.py

Lines changed: 6 additions & 2 deletions
@@ -63,7 +63,9 @@ def test_from_checkpoint_loads_model(
     checkpoint_path = checkpoints_dir / "model.ckpt"
     checkpoint_path.write_text("checkpoint")
 
-    OmegaConf.save(cfg_model_service, tmp_path / "model_config.yaml")
+    files_dir = tmp_path / "files"
+    files_dir.mkdir(parents=True)
+    OmegaConf.save(cfg_model_service, files_dir / "model_config.yaml")
 
     with pytest.MonkeyPatch.context() as mp:
         mp.setattr(
@@ -83,7 +85,9 @@ def test_from_checkpoint_config_overloads(
     checkpoint_path = checkpoints_dir / "model.ckpt"
     checkpoint_path.write_text("checkpoint")
 
-    OmegaConf.save(cfg_model_service, tmp_path / "model_config.yaml")
+    files_dir = tmp_path / "files"
+    files_dir.mkdir(parents=True)
+    OmegaConf.save(cfg_model_service, files_dir / "model_config.yaml")
 
     with pytest.MonkeyPatch.context() as mp:
         mp.setattr(
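
Both tests now build the directory layout that `from_checkpoint` expects: `checkpoints/` alongside `files/model_config.yaml`. A hypothetical fixture capturing that layout (illustrative only, not part of this commit) could look like this:

```python
# Hypothetical fixture sketching the run-directory layout the updated tests
# construct inline: checkpoints/ next to files/model_config.yaml.
from pathlib import Path

import pytest
from omegaconf import OmegaConf


@pytest.fixture
def run_dir(tmp_path: Path) -> Path:
    checkpoints_dir = tmp_path / "checkpoints"
    checkpoints_dir.mkdir()
    (checkpoints_dir / "model.ckpt").write_text("checkpoint")

    files_dir = tmp_path / "files"
    files_dir.mkdir()
    OmegaConf.save(OmegaConf.create({"model": {}}), files_dir / "model_config.yaml")
    return tmp_path
```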
