
Commit 75965d4

Logging resolved config (#2274)
1 parent 779569e commit 75965d4

File tree

1 file changed (+52 −29 lines)


torchtune/training/metric_logging.py

Lines changed: 52 additions & 29 deletions
@@ -23,6 +23,31 @@
 log = get_logger("DEBUG")
 
 
+def save_config(config: DictConfig) -> Path:
+    """
+    Save the OmegaConf configuration to a YAML file at `{config.output_dir}/torchtune_config.yaml`.
+
+    Args:
+        config (DictConfig): The OmegaConf config object to be saved. It must contain an `output_dir` attribute
+            specifying where the configuration file should be saved.
+
+    Returns:
+        Path: The path to the saved configuration file.
+
+    Note:
+        If the specified `output_dir` does not exist, it will be created.
+    """
+    try:
+        output_dir = Path(config.output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        output_config_fname = output_dir / "torchtune_config.yaml"
+        OmegaConf.save(config, output_config_fname)
+        return output_config_fname
+    except Exception as e:
+        log.warning(f"Error saving config.\nError: \n{e}.")
+
+
 class MetricLoggerInterface(Protocol):
     """Abstract metric logger."""
 
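For context, a quick sketch of how the new helper behaves once it is in scope (e.g. imported from torchtune.training.metric_logging); the output_dir value and seed field are illustrative:

    from omegaconf import OmegaConf

    from torchtune.training.metric_logging import save_config

    # Any DictConfig with an `output_dir` attribute qualifies
    cfg = OmegaConf.create({"output_dir": "/tmp/torchtune_run", "seed": 42})

    saved = save_config(cfg)
    print(saved)  # /tmp/torchtune_run/torchtune_config.yaml

    # The saved YAML round-trips to the same values
    assert OmegaConf.load(saved).seed == 42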
@@ -42,7 +67,7 @@ def log(
         pass
 
     def log_config(self, config: DictConfig) -> None:
-        """Logs the config
+        """Logs the config as file
 
         Args:
             config (DictConfig): config to log
@@ -99,6 +124,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
         self._file.write(f"Step {step} | {name}:{data}\n")
         self._file.flush()
 
+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
         self._file.write(f"Step {step} | ")
         for name, data in payload.items():
@@ -119,6 +147,9 @@ class StdoutLogger(MetricLoggerInterface):
     def log(self, name: str, data: Scalar, step: int) -> None:
         print(f"Step {step} | {name}:{data}")
 
+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
         print(f"Step {step} | ", end="")
         for name, data in payload.items():
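Because MetricLoggerInterface is a Protocol, any class with matching methods conforms, and log_config implementations can delegate to the shared helper exactly as DiskLogger and StdoutLogger do above. A minimal sketch of a custom logger following the same pattern; the class name is hypothetical and save_config is assumed imported from this module:

    from typing import Mapping, Union

    from omegaconf import DictConfig

    # Hypothetical logger that discards metrics but still persists the
    # resolved config. It satisfies MetricLoggerInterface structurally
    # (it is a Protocol), so no explicit inheritance is needed.
    class NullLogger:
        def log(self, name: str, data: Union[int, float], step: int) -> None:
            pass

        def log_dict(self, payload: Mapping[str, Union[int, float]], step: int) -> None:
            pass

        def log_config(self, config: DictConfig) -> None:
            # Reuse the module-level helper added in this commit
            _ = save_config(config)

        def close(self) -> None:
            pass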
@@ -183,6 +214,10 @@ def __init__(
         # Use dir if specified, otherwise use log_dir.
         self.log_dir = kwargs.pop("dir", log_dir)
 
+        # create log_dir if missing
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+
         _, self.rank = get_world_size_and_rank()
 
         if self._wandb.run is None and self.rank == 0:
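As an aside, the check-then-create pair added above can race when several processes initialize the logger at once; a race-free equivalent, sketched out of context with an illustrative directory name:

    import os

    log_dir = "wandb_logs"  # stand-in for self.log_dir
    # exist_ok=True makes the call a no-op if another process created
    # the directory first, avoiding the race between exists() and makedirs()
    os.makedirs(log_dir, exist_ok=True)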
@@ -219,23 +254,16 @@ def log_config(self, config: DictConfig) -> None:
             self._wandb.config.update(
                 resolved, allow_val_change=self.config_allow_val_change
             )
-            try:
-                output_config_fname = Path(
-                    os.path.join(
-                        config.output_dir,
-                        "torchtune_config.yaml",
-                    )
-                )
-                OmegaConf.save(config, output_config_fname)
 
-                log.info(f"Logging {output_config_fname} to W&B under Files")
+            # Also try to save the config as a file
+            output_config_fname = save_config(config)
+            try:
                 self._wandb.save(
                     output_config_fname, base_path=output_config_fname.parent
                 )
-
             except Exception as e:
                 log.warning(
-                    f"Error saving {output_config_fname} to W&B.\nError: \n{e}."
+                    f"Error uploading {output_config_fname} to W&B.\nError: \n{e}."
                     "Don't worry the config will be logged the W&B workspace"
                 )
 
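For reference, the base_path argument controls where the uploaded file appears in the run's Files tab; a standalone sketch (project name and path are illustrative):

    from pathlib import Path

    import wandb

    run = wandb.init(project="my-project")
    cfg_path = Path("/tmp/torchtune_run/torchtune_config.yaml")  # as written by save_config
    # Passing base_path=parent strips the parent directory, so the file
    # shows up as torchtune_config.yaml at the top level of Files
    wandb.save(str(cfg_path), base_path=str(cfg_path.parent))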
@@ -305,6 +333,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
         if self._writer:
             self._writer.add_scalar(name, data, global_step=step, new_style=True)
 
+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
         for name, data in payload.items():
             self.log(name, data, step)
@@ -387,13 +418,16 @@ def __init__(
387418
"Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
388419
) from e
389420

421+
# Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
422+
if "log_dir" in kwargs:
423+
del kwargs["log_dir"]
424+
390425
_, self.rank = get_world_size_and_rank()
391426

392427
# Declare it early so further methods don't crash in case of
393428
# Experiment Creation failure due to mis-named configuration for
394429
# example
395430
self.experiment = None
396-
397431
if self.rank == 0:
398432
self.experiment = comet_ml.start(
399433
api_key=api_key,
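A side note on the kwargs filtering added in this hunk: the membership check plus del can be collapsed into a single call, sketched here with illustrative values:

    kwargs = {"log_dir": "/tmp/logs", "workspace": "my-team"}  # illustrative
    # Drop the key if present; the default suppresses KeyError
    kwargs.pop("log_dir", None)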
@@ -421,24 +455,13 @@ def log_config(self, config: DictConfig) -> None:
             self.experiment.log_parameters(resolved)
 
             # Also try to save the config as a file
+            output_config_fname = save_config(config)
             try:
-                self._log_config_as_file(config)
+                self.experiment.log_asset(
+                    output_config_fname, file_name=output_config_fname.name
+                )
             except Exception as e:
-                log.warning(f"Error saving Config to disk.\nError: \n{e}.")
-                return
-
-    def _log_config_as_file(self, config: DictConfig):
-        output_config_fname = Path(
-            os.path.join(
-                config.checkpointer.checkpoint_dir,
-                "torchtune_config.yaml",
-            )
-        )
-        OmegaConf.save(config, output_config_fname)
-
-        self.experiment.log_asset(
-            output_config_fname, file_name="torchtune_config.yaml"
-        )
+                log.warning(f"Failed to upload config to Comet assets. Error: {e}")
 
     def close(self) -> None:
         if self.experiment is not None:
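And the Comet counterpart, for comparison; a standalone sketch (project name and path are illustrative):

    from pathlib import Path

    import comet_ml

    experiment = comet_ml.start(project_name="my-project")
    cfg_path = Path("/tmp/torchtune_run/torchtune_config.yaml")  # as written by save_config
    # Uploads the YAML under the experiment's Assets tab; file_name
    # sets the displayed asset name
    experiment.log_asset(cfg_path, file_name=cfg_path.name)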
