Skip to content

Commit b545ae1

Browse files
committed
🎨 Tidy up use of fully_deterministic
1 parent 9a313c9 commit b545ae1

1 file changed

Lines changed: 15 additions & 18 deletions

File tree

icenet_mp/model_service.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ def __init__(self, config: DictConfig) -> None:
2626
"""Initialize the model service."""
2727
self.config_ = config
2828

29-
random_config = config.get("random", {})
30-
seed = random_config.get("seed", None)
31-
fully_deterministic = random_config.get("fully_deterministic", False)
32-
33-
if seed is not None:
29+
# If a random seed was specified in the configuration, set it for reproducibility.
30+
if (seed := config.get("random", {}).get("seed", None)) is not None:
3431
seed = int(seed)
3532
os.environ["PYTHONHASHSEED"] = str(seed)
3633
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
@@ -39,7 +36,10 @@ def __init__(self, config: DictConfig) -> None:
3936
# If we are in fully deterministic mode, enable deterministic algorithms and
4037
# patch any known issues with them. We use warn_only=True to avoid segfaults on
4138
# unsupported operations.
42-
if fully_deterministic:
39+
self.fully_deterministic = config.get("random", {}).get(
40+
"fully_deterministic", False
41+
)
42+
if self.fully_deterministic:
4343
torch.use_deterministic_algorithms(True, warn_only=True) # noqa: FBT003
4444
patch_interpolate_antialias()
4545
log.warning(
@@ -204,21 +204,18 @@ def build_trainer(
204204
hydra.utils.instantiate(
205205
self.config["train"]["trainer"],
206206
callbacks=extra_callbacks,
207-
deterministic=self.config.get("random", {}).get(
208-
"fully_deterministic", False
209-
),
207+
deterministic=self.fully_deterministic,
210208
logger=extra_loggers,
211209
),
212210
)
213-
# Check warn_only survived Lightning's deterministic setup
214-
log.debug(
215-
"deterministic_algorithms_enabled: %s",
216-
torch.are_deterministic_algorithms_enabled(),
217-
)
218-
log.debug(
219-
"warn_only_enabled: %s",
220-
torch.is_deterministic_algorithms_warn_only_enabled(),
221-
)
211+
212+
# Check that fully_deterministic is set correctly
213+
if self.fully_deterministic != torch.are_deterministic_algorithms_enabled():
214+
log.warning(
215+
"fully_deterministic is set to %s but torch.are_deterministic_algorithms_enabled() is %s.",
216+
self.fully_deterministic,
217+
torch.are_deterministic_algorithms_enabled(),
218+
)
222219

223220
# Assign workers for data loading
224221
self.data_module.assign_workers(

0 commit comments

Comments
 (0)