@@ -36,6 +36,8 @@ class LoggingConfig:
3636 log_to_wandb: Whether to log to Weights & Biases.
3737 log_format: Format of the log messages.
3838 level: Sets the logging level.
39+ wandb_dir_in_experiment_dir: Whether to create the wandb_dir in the
40+ experiment_dir or in local /tmp (default False).
3941 """
4042
4143 project : str = "ace"
@@ -45,6 +47,7 @@ class LoggingConfig:
4547 log_to_wandb : bool = True
4648 log_format : str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
4749 level : str | int = logging .INFO
50+ wandb_dir_in_experiment_dir : bool = False
4851
4952 def __post_init__ (self ):
5053 self ._dist = Distributed .get_instance ()
@@ -76,7 +79,6 @@ def configure_wandb(
7679 self ,
7780 config : Mapping [str , Any ],
7881 env_vars : Mapping [str , Any ] | None = None ,
79- wandb_dir : str | None = DEFAULT_TMP_DIR ,
8082 resumable : bool = True ,
8183 resume : Any = None ,
8284 ** kwargs ,
@@ -94,14 +96,21 @@ def configure_wandb(
9496 )
9597 elif env_vars is not None :
9698 config_copy ["environment" ] = env_vars
99+
100+ experiment_dir = config ["experiment_dir" ]
101+ if self .wandb_dir_in_experiment_dir :
102+ wandb_dir = experiment_dir
103+ else :
104+ wandb_dir = DEFAULT_TMP_DIR
105+
97106 # must ensure wandb.configure is called before wandb.init
98107 wandb = WandB .get_instance ()
99108 wandb .configure (log_to_wandb = self .log_to_wandb )
100109 wandb .init (
101110 config = config_copy ,
102111 project = self .project ,
103112 entity = self .entity ,
104- experiment_dir = config [ " experiment_dir" ] ,
113+ experiment_dir = experiment_dir ,
105114 resumable = resumable ,
106115 dir = wandb_dir ,
107116 ** kwargs ,
0 commit comments