diff --git a/fme/core/logging_utils.py b/fme/core/logging_utils.py index 9763062d..e1654fe8 100644 --- a/fme/core/logging_utils.py +++ b/fme/core/logging_utils.py @@ -36,6 +36,8 @@ class LoggingConfig: log_to_wandb: Whether to log to Weights & Biases. log_format: Format of the log messages. level: Sets the logging level. + wandb_dir_in_experiment_dir: Whether to create the wandb_dir in the + experiment_dir or in local /tmp (default False). """ project: str = "ace" @@ -45,6 +47,7 @@ class LoggingConfig: log_to_wandb: bool = True log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" level: str | int = logging.INFO + wandb_dir_in_experiment_dir: bool = False def __post_init__(self): self._dist = Distributed.get_instance() @@ -76,7 +79,6 @@ def configure_wandb( self, config: Mapping[str, Any], env_vars: Mapping[str, Any] | None = None, - wandb_dir: str | None = DEFAULT_TMP_DIR, resumable: bool = True, resume: Any = None, **kwargs, @@ -94,6 +96,13 @@ def configure_wandb( ) elif env_vars is not None: config_copy["environment"] = env_vars + + experiment_dir = config["experiment_dir"] + if self.wandb_dir_in_experiment_dir: + wandb_dir = experiment_dir + else: + wandb_dir = DEFAULT_TMP_DIR + # must ensure wandb.configure is called before wandb.init wandb = WandB.get_instance() wandb.configure(log_to_wandb=self.log_to_wandb) @@ -101,7 +110,7 @@ def configure_wandb( config=config_copy, project=self.project, entity=self.entity, - experiment_dir=config["experiment_dir"], + experiment_dir=experiment_dir, resumable=resumable, dir=wandb_dir, **kwargs,