Skip to content

Commit 69757b1

Browse files
Add seed (#255)
* 🍱 Add seed for everything and deterministic and its config * 🔨 Edit config * 🔨 Fix linter errors * 🔨 Refactor codes, move from cli train to ModelService * Apply suggestion from @jemrobinson Co-authored-by: James Robinson <james.em.robinson@gmail.com> --------- Co-authored-by: James Robinson <james.em.robinson@gmail.com>
1 parent 95415f3 commit 69757b1

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

icenet_mp/config/base.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ defaults:
1010
- _self_
1111

1212
base_path: ../base
13+
14+
seed: 555

icenet_mp/model_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import hydra
77
import torch
8-
from lightning import Callback, Trainer
8+
from lightning import Callback, Trainer, seed_everything
99
from lightning.fabric.utilities import suggested_max_num_workers
1010
from lightning.pytorch.callbacks import ModelCheckpoint
1111
from omegaconf import DictConfig, OmegaConf
@@ -27,6 +27,8 @@ class ModelService:
2727
def __init__(self, config: DictConfig) -> None:
2828
"""Initialize the model service."""
2929
self.config_ = config
30+
if seed := config.get("seed", None):
31+
seed_everything(int(seed), workers=True)
3032
self.data_module_: CommonDataModule | None = None
3133
self.model_: BaseModel | None = None
3234
self.trainer_: Trainer | None = None
@@ -165,6 +167,7 @@ def trainer(self) -> Trainer:
165167
{
166168
"callbacks": self.extra_callbacks_,
167169
"logger": self.extra_loggers_,
170+
"deterministic": self.config.get("seed", None) is not None,
168171
},
169172
**self.config["train"]["trainer"],
170173
)

0 commit comments

Comments
 (0)