Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion icenet_mp/config/base_isambardai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ defaults:
- base
- _self_

base_path: /projects/u5gf/seaice/base/
base_path: /projects/public/u6iz/shared/base/
43 changes: 42 additions & 1 deletion icenet_mp/model_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import os
from pathlib import Path, PosixPath
from typing import cast

import hydra
import torch
import torch.nn.functional as F # noqa: N812
from lightning import Callback, Trainer, seed_everything
from lightning.fabric.utilities import suggested_max_num_workers
from lightning.pytorch.callbacks import ModelCheckpoint
Expand All @@ -19,12 +21,41 @@
log = logging.getLogger(__name__)


class _DeterministicInterpolate:
"""Monkey-patch F.interpolate to strip antialias=True for deterministic CUDA backward.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is something we want to do. We previously needed the antialias argument to avoid artifacts when resizing.


upsample_bilinear2d_aa has no deterministic CUDA backward pass, so we strip
the antialias argument globally to ensure deterministic behaviour.
"""

_applied = False # class-level flag

def __init__(self) -> None:
if _DeterministicInterpolate._applied:
return
self._original = F.interpolate
F.interpolate = self # type: ignore[assignment]
_DeterministicInterpolate._applied = True

def __call__(
self, tensor: torch.Tensor, *args: object, **kwargs: object
) -> torch.Tensor:
kwargs.pop("antialias", None)
return self._original(tensor, *args, **kwargs) # type: ignore[arg-type]


class ModelService:
def __init__(self, config: DictConfig) -> None:
    """Initialize the model service.

    Args:
        config: Hydra/OmegaConf configuration for the run. If it contains a
            truthy ``seed`` entry, deterministic training is enabled.
    """
    self.config_ = config
    if seed := config.get("seed", None):
        seed = int(seed)
        # NOTE(review): setting PYTHONHASHSEED at runtime only affects child
        # processes (e.g. dataloader workers); this interpreter's hash
        # randomisation was fixed at startup — confirm that is the intent.
        os.environ["PYTHONHASHSEED"] = str(seed)
        # cuBLAS is only bitwise-reproducible when each CUDA stream reuses a
        # fixed workspace. The format is ":<size_KiB>:<count>", so ":4096:8"
        # gives each stream up to eight 4 MiB workspace buffers. Required by
        # torch.use_deterministic_algorithms on CUDA >= 10.2; see the
        # "Results Reproducibility" section of the NVIDIA cuBLAS docs
        # (https://docs.nvidia.com/cuda/cublas/).
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        seed_everything(seed, workers=True)
        # warn_only=True: ops lacking a deterministic implementation warn
        # instead of raising, so training is not aborted mid-run.
        torch.use_deterministic_algorithms(True, warn_only=True)  # noqa: FBT003
        # Strip antialias=True from F.interpolate calls — its CUDA backward
        # pass has no deterministic implementation.
        _DeterministicInterpolate()

    self.data_module_: CommonDataModule | None = None
    self.model_: BaseModel | None = None

Expand Down Expand Up @@ -184,6 +215,16 @@ def build_trainer(
)
),
)
# Check warn_only survived Lightning's deterministic setup
log.debug(
"deterministic_algorithms_enabled: %s",
torch.are_deterministic_algorithms_enabled(),
)
log.debug(
"warn_only_enabled: %s",
torch.is_deterministic_algorithms_warn_only_enabled(),
)

# Assign workers for data loading
self.data_module.assign_workers(suggested_max_num_workers(trainer.num_devices))

Expand Down
Loading