Skip to content

Commit fe3e17f

Browse files
committed
Refactored set_seed into a standalone function and updated imports.
1 parent 96bbc0d commit fe3e17f

2 files changed

Lines changed: 14 additions & 10 deletions

File tree

mipcandy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212
from mipcandy.profiler import ProfilerFrame, Profiler
1313
from mipcandy.run import config
1414
from mipcandy.sanity_check import num_trainable_params, model_complexity_info, SanityCheckResult, sanity_check
15-
from mipcandy.training import TrainerToolbox, Trainer
15+
from mipcandy.training import TrainerToolbox, Trainer, set_seed
1616
from mipcandy.types import Setting, Settings, Params, Transform, SupportedPredictant, Colormap, Device, Shape2d, \
1717
Shape3d, Shape, AmbiguousShape, Paddings2d, Paddings3d, Paddings

mipcandy/training.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ class TrainerTracker(object):
5959
worst_case: int | None = None
6060

6161

62+
def set_seed(seed: int) -> None:
63+
np.random.seed(seed)
64+
torch.manual_seed(seed)
65+
torch.cuda.manual_seed(seed)
66+
torch.cuda.manual_seed_all(seed)
67+
torch.backends.cudnn.benchmark = False
68+
torch.backends.cudnn.deterministic = True
69+
random_seed(seed)
70+
np.random.seed(seed)
71+
environ['PYTHONHASHSEED'] = str(seed)
72+
73+
6274
class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
6375
def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
6476
validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True,
@@ -197,15 +209,7 @@ def set_frontend(self, frontend: type[Frontend], *, path_to_secrets: str | PathL
197209
self._frontend = frontend(load_secrets(path=path_to_secrets) if path_to_secrets else load_secrets())
198210

199211
def set_seed(self, seed: int) -> None:
200-
np.random.seed(seed)
201-
torch.manual_seed(seed)
202-
torch.cuda.manual_seed(seed)
203-
torch.cuda.manual_seed_all(seed)
204-
torch.backends.cudnn.benchmark = False
205-
torch.backends.cudnn.deterministic = True
206-
random_seed(seed)
207-
np.random.seed(seed)
208-
environ['PYTHONHASHSEED'] = str(seed)
212+
set_seed(seed)
209213
if self.initialized():
210214
self.log(f"Set to manual seed {seed}")
211215

0 commit comments

Comments
 (0)