77from hydra_zen .third_party .pydantic import pydantic_parser
88from omegaconf import DictConfig , OmegaConf
99
10- from ml_project_template .utils import ConfigKeys , get_hydra_output_dir , logger , seed_everything
10+ from ml_project_template .utils import ConfigKeys , get_output_dir , logger
1111from ml_project_template .wandb import WandBRun
1212
1313
14- def pre_call (root_config : DictConfig , log_debug : bool = False ) -> None :
14+ def pre_call (root_config : DictConfig , seed_fn : Callable [[ int ], None ] | None = None , log_debug : bool = False ) -> None :
1515 """Logs the config, sets the seed and initializes a WandB run before config instantiation.
1616
1717 Args:
1818 root_config: Unresolved config.
19+ seed_fn: Function to use for seeding the run.
1920 log_debug: Whether to log the config, seed and output path.
2021 """
2122 if log_debug :
@@ -27,7 +28,10 @@ def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
2728 return
2829
2930 if (seed := config .get (ConfigKeys .SEED )) is not None :
30- seed_everything (seed )
31+ if seed_fn is None :
32+ raise ValueError ("No seeding function was set for the given seed." )
33+
34+ seed_fn (seed )
3135 logger .debug (f"Set seed to { seed } ." )
3236 else :
3337 logger .warning ("No seed was configured! Run may not be reproducible." )
@@ -37,7 +41,7 @@ def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
3741 else :
3842 logger .debug (f"Running config:\n { to_yaml (root_config )} " )
3943
40- output_path = get_hydra_output_dir ()
44+ output_path = get_output_dir ()
4145 logger .debug (f"Saving outputs in { output_path } " )
4246
4347 logger .setLevel (logging .INFO )
@@ -48,17 +52,18 @@ def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
4852 wandb .save (output_path / ".hydra/*" , base_path = output_path , policy = "now" )
4953
5054
51- def run (main_function : Callable , log_debug : bool = True ) -> None :
55+ def run (main_function : Callable , seed_fn : Callable [[ int ], None ] | None = None , log_debug : bool = True ) -> None :
5256 """Configure and run a given function using hydra-zen.
5357
5458 Args:
5559 main_function: Function to configure and run.
60+ seed_fn: Function to use for seeding the run.
5661 log_debug: Whether to log debug information from the `pre_call` function.
5762 """
5863 store .add_to_hydra_store ()
5964 zen (
6065 main_function ,
61- pre_call = partial (pre_call , log_debug = log_debug ),
66+ pre_call = partial (pre_call , seed_fn = seed_fn , log_debug = log_debug ),
6267 resolve_pre_call = False ,
6368 instantiation_wrapper = pydantic_parser ,
6469 ).hydra_main (
0 commit comments