diff --git a/icenet_mp/config/base_isambardai.yaml b/icenet_mp/config/base_isambardai.yaml index 62f73011a..d47d1d808 100644 --- a/icenet_mp/config/base_isambardai.yaml +++ b/icenet_mp/config/base_isambardai.yaml @@ -2,4 +2,4 @@ defaults: - base - _self_ -base_path: /projects/u5gf/seaice/base/ +base_path: /projects/public/u6iz/shared/base/ diff --git a/icenet_mp/losses/weighted_mse_loss.py b/icenet_mp/losses/weighted_mse_loss.py index 880ef24a2..84f8be6ea 100644 --- a/icenet_mp/losses/weighted_mse_loss.py +++ b/icenet_mp/losses/weighted_mse_loss.py @@ -20,6 +20,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: **kwargs: Keyword arguments passed to torch.nn.MSELoss. """ + kwargs["reduction"] = "none" super().__init__(*args, **kwargs) def forward( # type: ignore[override] @@ -43,5 +44,5 @@ def forward( # type: ignore[override] targets = targets.squeeze() sample_weights = sample_weights.squeeze() - loss = super().forward(100 * y_hat, 100 * targets) * sample_weights + loss = super().forward(y_hat, targets) * sample_weights return loss.mean() diff --git a/icenet_mp/model_service.py b/icenet_mp/model_service.py index 9d60cc0e5..e456a3272 100644 --- a/icenet_mp/model_service.py +++ b/icenet_mp/model_service.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path, PosixPath from typing import cast @@ -23,8 +24,13 @@ class ModelService: def __init__(self, config: DictConfig) -> None: """Initialize the model service.""" self.config_ = config - if seed := config.get("seed", None): - seed_everything(int(seed), workers=True) + if (seed := config.get("seed", None)) is not None: + seed = int(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + seed_everything(seed, workers=True) + torch.use_deterministic_algorithms(True, warn_only=True) # noqa: FBT003 + self.data_module_: CommonDataModule | None = None self.model_: BaseModel | None = None @@ -184,6 +190,20 @@ def build_trainer( ) ), ) + # Re-apply warn_only since Lightning may override it when deterministic=True + if self.config.get("seed", None): + torch.use_deterministic_algorithms(True, warn_only=True) # noqa: FBT003 + + # Check warn_only survived Lightning's deterministic setup + log.debug( + "deterministic_algorithms_enabled: %s", + torch.are_deterministic_algorithms_enabled(), + ) + log.debug( + "warn_only_enabled: %s", + torch.is_deterministic_algorithms_warn_only_enabled(), + ) + # Assign workers for data loading self.data_module.assign_workers(suggested_max_num_workers(trainer.num_devices))