Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion icenet_mp/config/base_isambardai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ defaults:
- base
- _self_

base_path: /projects/u5gf/seaice/base/
base_path: /projects/public/u6iz/shared/base/
3 changes: 2 additions & 1 deletion icenet_mp/losses/weighted_mse_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
**kwargs: Keyword arguments passed to torch.nn.MSELoss.

"""
kwargs["reduction"] = "none"
super().__init__(*args, **kwargs)

def forward( # type: ignore[override]
Expand All @@ -43,5 +44,5 @@ def forward( # type: ignore[override]
targets = targets.squeeze()
sample_weights = sample_weights.squeeze()

loss = super().forward(100 * y_hat, 100 * targets) * sample_weights
loss = super().forward(y_hat, targets) * sample_weights
return loss.mean()
24 changes: 22 additions & 2 deletions icenet_mp/model_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path, PosixPath
from typing import cast

Expand All @@ -23,8 +24,13 @@ class ModelService:
def __init__(self, config: DictConfig) -> None:
"""Initialize the model service."""
self.config_ = config
if seed := config.get("seed", None):
seed_everything(int(seed), workers=True)
if (seed := config.get("seed", None)) is not None:
seed = int(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reference https://docs.nvidia.com/cuda/cublas/ and explain what this means

seed_everything(seed, workers=True)
torch.use_deterministic_algorithms(True, warn_only=True) # noqa: FBT003

self.data_module_: CommonDataModule | None = None
self.model_: BaseModel | None = None

Expand Down Expand Up @@ -184,6 +190,20 @@ def build_trainer(
)
),
)
# Re-apply warn_only since Lightning may override it when deterministic=True
if self.config.get("seed", None):
torch.use_deterministic_algorithms(True, warn_only=True) # noqa: FBT003

# Check warn_only survived Lightning's deterministic setup
log.debug(
"deterministic_algorithms_enabled: %s",
torch.are_deterministic_algorithms_enabled(),
)
log.debug(
"warn_only_enabled: %s",
torch.is_deterministic_algorithms_warn_only_enabled(),
)

# Assign workers for data loading
self.data_module.assign_workers(suggested_max_num_workers(trainer.num_devices))

Expand Down
Loading