23
23
# Module-level logger shared by all loggers in this file.
# NOTE(review): the level "DEBUG" is hard-coded here — confirm this is the
# intended default rather than something configurable by the caller.
log = get_logger ("DEBUG" )
24
24
25
25
26
def save_config(config: DictConfig) -> Path:
    """
    Save the OmegaConf configuration to a YAML file at
    ``{config.output_dir}/torchtune_config.yaml``.

    Args:
        config (DictConfig): The OmegaConf config object to be saved. It must
            contain an ``output_dir`` attribute specifying where the
            configuration file should be saved.

    Returns:
        Path: The path to the saved configuration file. On failure a warning
        is logged and ``None`` is returned (best-effort save).

    Note:
        If the specified ``output_dir`` does not exist, it will be created.
    """
    # Build the target path *before* the try-block: the except handler below
    # interpolates `output_config_fname`, so it must be bound even when
    # directory creation or the save itself raises (previously this could
    # raise NameError inside the handler).
    output_config_fname = Path(config.output_dir) / "torchtune_config.yaml"
    try:
        output_config_fname.parent.mkdir(parents=True, exist_ok=True)
        log.info(f"Writing config to {output_config_fname}")
        OmegaConf.save(config, output_config_fname)
        return output_config_fname
    except Exception as e:
        # Best-effort: a failed config dump should not abort training.
        log.warning(f"Error saving config to {output_config_fname}.\nError: \n{e}.")
38
50
39
51
40
52
class MetricLoggerInterface (Protocol ):
@@ -56,7 +68,7 @@ def log(
56
68
pass
57
69
58
70
def log_config (self , config : DictConfig ) -> None :
59
- """Logs the config
71
+ """Logs the config as file
60
72
61
73
Args:
62
74
config (DictConfig): config to log
@@ -114,7 +126,7 @@ def log(self, name: str, data: Scalar, step: int) -> None:
114
126
self ._file .flush ()
115
127
116
128
def log_config(self, config: DictConfig) -> None:
    """Persist the resolved config to disk via :func:`save_config`."""
    # Return value (the written path) is not needed here.
    save_config(config)
118
130
119
131
def log_dict (self , payload : Mapping [str , Scalar ], step : int ) -> None :
120
132
self ._file .write (f"Step { step } | " )
@@ -136,6 +148,9 @@ class StdoutLogger(MetricLoggerInterface):
136
148
def log(self, name: str, data: Scalar, step: int) -> None:
    """Print a single named scalar metric for the given step to stdout."""
    message = f"Step {step} | {name}:{data}"
    print(message)
138
150
151
def log_config(self, config: DictConfig) -> None:
    """Save the config to a YAML file using the shared helper."""
    # Discard the returned path; stdout logging has no use for it.
    save_config(config)
139
154
def log_dict (self , payload : Mapping [str , Scalar ], step : int ) -> None :
140
155
print (f"Step { step } | " , end = "" )
141
156
for name , data in payload .items ():
@@ -200,6 +215,10 @@ def __init__(
200
215
# Use dir if specified, otherwise use log_dir.
201
216
self .log_dir = kwargs .pop ("dir" , log_dir )
202
217
218
+ # create log_dir if missing
219
+ if not os .path .exists (self .log_dir ):
220
+ os .makedirs (self .log_dir )
221
+
203
222
_ , self .rank = get_world_size_and_rank ()
204
223
205
224
if self ._wandb .run is None and self .rank == 0 :
@@ -236,23 +255,17 @@ def log_config(self, config: DictConfig) -> None:
236
255
self ._wandb .config .update (
237
256
resolved , allow_val_change = self .config_allow_val_change
238
257
)
239
- try :
240
- output_config_fname = Path (
241
- os .path .join (
242
- config .output_dir ,
243
- "torchtune_config.yaml" ,
244
- )
245
- )
246
- OmegaConf .save (config , output_config_fname )
247
258
248
- log .info (f"Logging { output_config_fname } to W&B under Files" )
259
+ # Also try to save the config as a file
260
+ output_config_fname = save_config (config )
261
+ try :
262
+ log .info (f"Uploading { output_config_fname } to W&B under Files" )
249
263
self ._wandb .save (
250
264
output_config_fname , base_path = output_config_fname .parent
251
265
)
252
-
253
266
except Exception as e :
254
267
log .warning (
255
- f"Error saving { output_config_fname } to W&B.\n Error: \n { e } ."
268
+ f"Error uploading { output_config_fname } to W&B.\n Error: \n { e } ."
256
269
"Don't worry the config will be logged the W&B workspace"
257
270
)
258
271
@@ -322,6 +335,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
322
335
if self ._writer :
323
336
self ._writer .add_scalar (name , data , global_step = step , new_style = True )
324
337
338
def log_config(self, config: DictConfig) -> None:
    """Write the config to ``{config.output_dir}/torchtune_config.yaml``."""
    # Path returned by the helper is unused by this logger.
    save_config(config)
325
341
def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
    """Log every metric in ``payload`` at ``step`` by delegating to :meth:`log`."""
    for metric_name, metric_value in payload.items():
        self.log(metric_name, metric_value, step)
@@ -404,13 +420,15 @@ def __init__(
404
420
"Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
405
421
) from e
406
422
423
+ # Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
424
+ del kwargs ["log_dir" ]
425
+
407
426
_ , self .rank = get_world_size_and_rank ()
408
427
409
428
# Declare it early so further methods don't crash in case of
410
429
# Experiment Creation failure due to mis-named configuration for
411
430
# example
412
431
self .experiment = None
413
-
414
432
if self .rank == 0 :
415
433
self .experiment = comet_ml .start (
416
434
api_key = api_key ,
@@ -438,24 +456,14 @@ def log_config(self, config: DictConfig) -> None:
438
456
self .experiment .log_parameters (resolved )
439
457
440
458
# Also try to save the config as a file
459
+ output_config_fname = save_config (config )
441
460
try :
442
- self ._log_config_as_file (config )
461
+ log .info (f"Uploading { output_config_fname } to Comet as an asset." )
462
+ self .experiment .log_asset (
463
+ output_config_fname , file_name = output_config_fname .name
464
+ )
443
465
except Exception as e :
444
- log .warning (f"Error saving Config to disk.\n Error: \n { e } ." )
445
- return
446
-
447
- def _log_config_as_file (self , config : DictConfig ):
448
- output_config_fname = Path (
449
- os .path .join (
450
- config .checkpointer .checkpoint_dir ,
451
- "torchtune_config.yaml" ,
452
- )
453
- )
454
- OmegaConf .save (config , output_config_fname )
455
-
456
- self .experiment .log_asset (
457
- output_config_fname , file_name = "torchtune_config.yaml"
458
- )
466
+ log .warning (f"Failed to upload config to Comet assets. Error: { e } " )
459
467
460
468
def close (self ) -> None :
461
469
if self .experiment is not None :
0 commit comments