@@ -26,11 +26,8 @@ def __init__(self, config: DictConfig) -> None:
2626 """Initialize the model service."""
2727 self .config_ = config
2828
29- random_config = config .get ("random" , {})
30- seed = random_config .get ("seed" , None )
31- fully_deterministic = random_config .get ("fully_deterministic" , False )
32-
33- if seed is not None :
29+ # If a random seed was specified in the configuration, set it for reproducibility.
30+ if (seed := config .get ("random" , {}).get ("seed" , None )) is not None :
3431 seed = int (seed )
3532 os .environ ["PYTHONHASHSEED" ] = str (seed )
3633 os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":4096:8"
@@ -39,7 +36,10 @@ def __init__(self, config: DictConfig) -> None:
3936 # If we are in fully deterministic mode, enable deterministic algorithms and
4037 # patch any known issues with them. We use warn_only=True to avoid segfaults on
4138 # unsupported operations.
42- if fully_deterministic :
39+ self .fully_deterministic = config .get ("random" , {}).get (
40+ "fully_deterministic" , False
41+ )
42+ if self .fully_deterministic :
4343 torch .use_deterministic_algorithms (True , warn_only = True ) # noqa: FBT003
4444 patch_interpolate_antialias ()
4545 log .warning (
@@ -204,21 +204,18 @@ def build_trainer(
204204 hydra .utils .instantiate (
205205 self .config ["train" ]["trainer" ],
206206 callbacks = extra_callbacks ,
207- deterministic = self .config .get ("random" , {}).get (
208- "fully_deterministic" , False
209- ),
207+ deterministic = self .fully_deterministic ,
210208 logger = extra_loggers ,
211209 ),
212210 )
213- # Check warn_only survived Lightning's deterministic setup
214- log .debug (
215- "deterministic_algorithms_enabled: %s" ,
216- torch .are_deterministic_algorithms_enabled (),
217- )
218- log .debug (
219- "warn_only_enabled: %s" ,
220- torch .is_deterministic_algorithms_warn_only_enabled (),
221- )
211+
212+ # Check that fully_deterministic is set correctly
213+ if self .fully_deterministic != torch .are_deterministic_algorithms_enabled ():
214+ log .warning (
215+ "fully_deterministic is set to %s but torch.are_deterministic_algorithms_enabled() is %s." ,
216+ self .fully_deterministic ,
217+ torch .are_deterministic_algorithms_enabled (),
218+ )
222219
223220 # Assign workers for data loading
224221 self .data_module .assign_workers (
0 commit comments