@@ -59,6 +59,18 @@ class TrainerTracker(object):
5959 worst_case : int | None = None
6060
6161
62+ def set_seed (seed : int ) -> None :
63+ np .random .seed (seed )
64+ torch .manual_seed (seed )
65+ torch .cuda .manual_seed (seed )
66+ torch .cuda .manual_seed_all (seed )
67+ torch .backends .cudnn .benchmark = False
68+ torch .backends .cudnn .deterministic = True
69+ random_seed (seed )
70+ np .random .seed (seed )
71+ environ ['PYTHONHASHSEED' ] = str (seed )
72+
73+
6274class Trainer (WithPaddingModule , WithNetwork , metaclass = ABCMeta ):
6375 def __init__ (self , trainer_folder : str | PathLike [str ], dataloader : DataLoader [tuple [torch .Tensor , torch .Tensor ]],
6476 validation_dataloader : DataLoader [tuple [torch .Tensor , torch .Tensor ]], * , recoverable : bool = True ,
@@ -197,15 +209,7 @@ def set_frontend(self, frontend: type[Frontend], *, path_to_secrets: str | PathL
197209 self ._frontend = frontend (load_secrets (path = path_to_secrets ) if path_to_secrets else load_secrets ())
198210
199211 def set_seed (self , seed : int ) -> None :
200- np .random .seed (seed )
201- torch .manual_seed (seed )
202- torch .cuda .manual_seed (seed )
203- torch .cuda .manual_seed_all (seed )
204- torch .backends .cudnn .benchmark = False
205- torch .backends .cudnn .deterministic = True
206- random_seed (seed )
207- np .random .seed (seed )
208- environ ['PYTHONHASHSEED' ] = str (seed )
212+ set_seed (seed )
209213 if self .initialized ():
210214 self .log (f"Set to manual seed { seed } " )
211215
0 commit comments