@@ -23,6 +23,31 @@
log = get_logger("DEBUG")


+def save_config(config: DictConfig) -> Path:
+    """
+    Save the OmegaConf configuration to a YAML file at `{config.output_dir}/torchtune_config.yaml`.
+
+    Args:
+        config (DictConfig): The OmegaConf config object to be saved. It must contain an `output_dir` attribute
+            specifying where the configuration file should be saved.
+
+    Returns:
+        Path: The path to the saved configuration file.
+
+    Note:
+        If the specified `output_dir` does not exist, it will be created.
+    """
+    try:
+        output_dir = Path(config.output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        output_config_fname = output_dir / "torchtune_config.yaml"
+        OmegaConf.save(config, output_config_fname)
+        return output_config_fname
+    except Exception as e:
+        log.warning(f"Error saving config.\nError: \n{e}.")
+
+
class MetricLoggerInterface(Protocol):
    """Abstract metric logger."""
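For illustration, a minimal usage sketch of the new helper (the config contents below are hypothetical; only `output_dir` is required):

```python
from omegaconf import OmegaConf

# Hypothetical config: save_config only requires `output_dir` to be present.
cfg = OmegaConf.create({"output_dir": "/tmp/torchtune_run", "seed": 42})

path = save_config(cfg)
print(path)  # /tmp/torchtune_run/torchtune_config.yaml
```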
@@ -42,7 +67,7 @@ def log(
        pass

    def log_config(self, config: DictConfig) -> None:
-        """Logs the config
+        """Logs the config as a file

        Args:
            config (DictConfig): config to log
@@ -99,6 +124,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
        self._file.write(f"Step {step} | {name}:{data}\n")
        self._file.flush()

+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        self._file.write(f"Step {step} | ")
        for name, data in payload.items():
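The same two-line `log_config` is added to `StdoutLogger` and `TensorBoardLogger` in the hunks below, so the simple loggers share one code path. A sketch of the shared behavior, with constructor signatures assumed from the surrounding file:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"output_dir": "/tmp/torchtune_run"})  # hypothetical

# Assumes StdoutLogger() and DiskLogger(log_dir=...) as defined in this file.
for logger in (StdoutLogger(), DiskLogger(log_dir="/tmp/torchtune_run/logs")):
    # Each call writes the same {output_dir}/torchtune_config.yaml file.
    logger.log_config(cfg)
```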
@@ -119,6 +147,9 @@ class StdoutLogger(MetricLoggerInterface):
    def log(self, name: str, data: Scalar, step: int) -> None:
        print(f"Step {step} | {name}:{data}")

+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        print(f"Step {step} | ", end="")
        for name, data in payload.items():
@@ -183,6 +214,10 @@ def __init__(
        # Use dir if specified, otherwise use log_dir.
        self.log_dir = kwargs.pop("dir", log_dir)

+        # create log_dir if missing
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+
        _, self.rank = get_world_size_and_rank()

        if self._wandb.run is None and self.rank == 0:
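A side note on the directory guard: `os.makedirs` with `exist_ok=True` collapses the check-then-create pair into one call and avoids the race when several ranks start at once. A sketch with a hypothetical path:

```python
import os

log_dir = "/tmp/torchtune_run/wandb"  # hypothetical path

# Equivalent to the exists()/makedirs() pair above, without the
# time-of-check/time-of-use race between concurrent processes.
os.makedirs(log_dir, exist_ok=True)
```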
@@ -219,23 +254,16 @@ def log_config(self, config: DictConfig) -> None:
            self._wandb.config.update(
                resolved, allow_val_change=self.config_allow_val_change
            )
-            try:
-                output_config_fname = Path(
-                    os.path.join(
-                        config.output_dir,
-                        "torchtune_config.yaml",
-                    )
-                )
-                OmegaConf.save(config, output_config_fname)

-                log.info(f"Logging {output_config_fname} to W&B under Files")
+            # Also try to save the config as a file
+            output_config_fname = save_config(config)
+            try:
                self._wandb.save(
                    output_config_fname, base_path=output_config_fname.parent
                )
-
            except Exception as e:
                log.warning(
-                    f"Error saving {output_config_fname} to W&B.\nError: \n{e}."
+                    f"Error uploading {output_config_fname} to W&B.\nError: \n{e}."
                    "Don't worry, the config will be logged to the W&B workspace."
                )
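The `base_path` argument matters here: W&B strips it from the saved path, so the file appears as `torchtune_config.yaml` at the top of the run's Files tab rather than under its absolute path. A hedged sketch (project name hypothetical):

```python
import wandb
from pathlib import Path

wandb.init(project="torchtune")  # hypothetical project name
fname = Path("/tmp/torchtune_run/torchtune_config.yaml")

# base_path is stripped from the upload path, leaving just the file name.
wandb.save(str(fname), base_path=str(fname.parent))
```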
@@ -305,6 +333,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
        if self._writer:
            self._writer.add_scalar(name, data, global_step=step, new_style=True)

+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        for name, data in payload.items():
            self.log(name, data, step)
@@ -387,13 +418,16 @@ def __init__(
                "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
            ) from e

+        # Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
+        if "log_dir" in kwargs:
+            del kwargs["log_dir"]
+
        _, self.rank = get_world_size_and_rank()

        # Declare it early so further methods don't crash in case of
        # Experiment Creation failure due to mis-named configuration for
        # example
        self.experiment = None
-
        if self.rank == 0:
            self.experiment = comet_ml.start(
                api_key=api_key,
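The two-line key removal could equally be written as a single `dict.pop` with a default, which is a no-op when the key is absent:

```python
kwargs = {"log_dir": "/tmp/logs", "workspace": "my-team"}  # hypothetical values

# Equivalent to: if "log_dir" in kwargs: del kwargs["log_dir"]
kwargs.pop("log_dir", None)
```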
@@ -421,24 +455,13 @@ def log_config(self, config: DictConfig) -> None:
            self.experiment.log_parameters(resolved)

        # Also try to save the config as a file
+        output_config_fname = save_config(config)
        try:
-            self._log_config_as_file(config)
+            self.experiment.log_asset(
+                output_config_fname, file_name=output_config_fname.name
+            )
        except Exception as e:
-            log.warning(f"Error saving Config to disk.\nError: \n{e}.")
-            return
-
-    def _log_config_as_file(self, config: DictConfig):
-        output_config_fname = Path(
-            os.path.join(
-                config.checkpointer.checkpoint_dir,
-                "torchtune_config.yaml",
-            )
-        )
-        OmegaConf.save(config, output_config_fname)
-
-        self.experiment.log_asset(
-            output_config_fname, file_name="torchtune_config.yaml"
-        )
+            log.warning(f"Failed to upload config to Comet assets. Error: {e}")

    def close(self) -> None:
        if self.experiment is not None:
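For completeness, a minimal sketch of the new Comet upload path (experiment setup left to the ambient Comet configuration; the config contents are hypothetical):

```python
import comet_ml
from omegaconf import OmegaConf

cfg = OmegaConf.create({"output_dir": "/tmp/torchtune_run"})  # hypothetical
experiment = comet_ml.start()  # picks up ambient Comet configuration

path = save_config(cfg)  # helper added at the top of this diff

# file_name sets the display name under the experiment's Assets tab.
experiment.log_asset(path, file_name=path.name)
```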