8 | 8 | from lightning import Callback, Trainer |
9 | 9 | from lightning.fabric.utilities import suggested_max_num_workers |
10 | 10 | from lightning.pytorch.callbacks import ModelCheckpoint |
11 | | -from lightning.pytorch.loggers import WandbLogger |
12 | 11 | from omegaconf import DictConfig, OmegaConf |
13 | 12 | from wandb.sdk.lib.runid import generate_id |
14 | | -from wandb.wandb_run import Run |
15 | 13 |
16 | 14 | from icenet_mp.callbacks import UnconditionalCheckpoint |
17 | 15 | from icenet_mp.data_loaders import CommonDataModule |
18 | 16 | from icenet_mp.models.base_model import BaseModel |
19 | 17 | from icenet_mp.types import SupportsMetadata |
20 | | -from icenet_mp.utils import get_device_name, get_timestamp |
| 18 | +from icenet_mp.utils import get_device_name, get_timestamp, get_wandb_run |
21 | 19 |
22 | 20 | if TYPE_CHECKING: |
23 | 21 | from lightning.pytorch.loggers import Logger as LightningLogger |
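
The inline `WandbLogger` scan is replaced by a `get_wandb_run` helper imported from `icenet_mp.utils`. Its body is not part of this diff; a minimal sketch of what it presumably does, reconstructed from the loop removed below:

```python
# Sketch (assumption): get_wandb_run reconstructed from the removed logger loop;
# the actual implementation lives in icenet_mp.utils and is not shown in this diff.
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from wandb.wandb_run import Run


def get_wandb_run(trainer: Trainer) -> Run | None:
    """Return the wandb Run behind the first WandbLogger attached to the trainer."""
    for lightning_logger in trainer.loggers:
        if not isinstance(lightning_logger, WandbLogger):
            continue
        # The experiment can be a stub in fast_dev_run / offline modes,
        # so check that it really is a wandb Run before returning it.
        if isinstance(experiment := lightning_logger.experiment, Run):
            return experiment
    return None
```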
@@ -79,7 +77,7 @@ def from_checkpoint( |
79 | 77 | # Build a combined model configuration where the command line config takes |
80 | 78 | # precedence except for the "model", "predict" and "train" keys which are |
81 | 79 | # related to training the model. |
82 | | - config_path = checkpoint_path.parent.parent / "model_config.yaml" |
| 80 | + config_path = checkpoint_path.parent.parent / "files" / "model_config.yaml" |
83 | 81 | try: |
84 | 82 | # Load the model configuration from the checkpoint directory |
85 | 83 | ckpt_config = DictConfig(OmegaConf.load(config_path)) |
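
The config path gains a `files/` component so it matches where wandb places saved files inside a run directory. Assuming the standard wandb sync-dir layout (and the `checkpoints/` dirpath set in `configure_trainer` below), the paths line up like this:

```python
# Assumed run-directory layout (wandb sync_dir conventions; not verified here):
#
#   <run_directory>/                 # e.g. .../wandb/run-<timestamp>-<id>/
#   ├── checkpoints/                 # callback.dirpath in configure_trainer
#   │   └── <name>.ckpt              # checkpoint_path
#   └── files/                       # wandb's local files directory
#       └── model_config.yaml        # checkpoint_path.parent.parent / "files" / ...
```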
@@ -132,16 +130,12 @@ def model(self) -> BaseModel: |
132 | 130 |
133 | 131 | @property |
134 | 132 | def run_directory(self) -> Path: |
135 | | - """Get run directory from wandb logger or generate one in the same format.""" |
| 133 | + """Get run directory from Wandb or generate one in the same format.""" |
136 | 134 | if not self.run_directory_: |
137 | | - # Get the run directory from the WandbLogger if it exists |
138 | | - for lightning_logger in self.trainer.loggers: |
139 | | - if not isinstance(lightning_logger, WandbLogger): |
140 | | - continue |
141 | | - if not isinstance(experiment := lightning_logger.experiment, Run): |
142 | | - continue |
143 | | - self.run_directory_ = Path(experiment._settings.sync_dir) |
144 | | - break |
| 135 | + # Get the run directory from Wandb if it exists |
| 136 | + wandb_run = get_wandb_run(self.trainer) |
| 137 | + if wandb_run: |
| 138 | + self.run_directory_ = Path(wandb_run._settings.sync_dir) |
145 | 139 |
146 | 140 | # Otherwise generate a new run directory |
147 | 141 | if not self.run_directory_: |
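
Note that `_settings.sync_dir` is a private attribute of the wandb SDK. If that ever becomes a problem, a possible alternative (assuming `Run.dir` keeps returning `<sync_dir>/files`, as it does in current wandb releases) would be:

```python
# Possible alternative using the public Run.dir property (assumption: it
# returns "<sync_dir>/files", so its parent is the run directory itself).
if wandb_run := get_wandb_run(self.trainer):
    self.run_directory_ = Path(wandb_run.dir).parent
```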
@@ -233,7 +227,10 @@ def configure_trainer( |
233 | 227 | callback.dirpath = self.run_directory / "checkpoints" |
234 | 228 |
235 | 229 | # Save model config to the run directory |
236 | | - OmegaConf.save(self.config, self.run_directory / "model_config.yaml") |
| 230 | + model_config_path = self.run_directory / "files" / "model_config.yaml" |
| 231 | + OmegaConf.save(self.config, model_config_path) |
| 232 | + if wandb_run := get_wandb_run(self.trainer): |
| 233 | + wandb_run.save(model_config_path, base_path=model_config_path.parent) |
237 | 234 |
238 | 235 | def evaluate(self) -> None: |
239 | 236 | """Evaluate a trained model.""" |
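
One thing worth double-checking: `OmegaConf.save` does not create missing parent directories, and when no wandb run is active the freshly generated run directory may not contain `files/` yet. A defensive sketch:

```python
# Defensive variant (sketch): make sure files/ exists before saving, since
# OmegaConf.save opens the target path directly without creating parents.
model_config_path = self.run_directory / "files" / "model_config.yaml"
model_config_path.parent.mkdir(parents=True, exist_ok=True)
OmegaConf.save(self.config, model_config_path)
```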