Skip to content

Commit 1e90052

Browse files
committed
refactor log_config implementation for each logger class
1 parent 1885cf3 commit 1e90052

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

torchtune/training/metric_logging.py

+47-39
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,30 @@
2323
log = get_logger("DEBUG")
2424

2525

26-
def save_config(config: DictConfig) -> Path:
    """
    Save the OmegaConf configuration to a YAML file at `{config.output_dir}/torchtune_config.yaml`.

    Args:
        config (DictConfig): The OmegaConf config object to be saved. It must contain an `output_dir` attribute
            specifying where the configuration file should be saved.

    Returns:
        Path: The path to the saved configuration file. Returns ``None`` if saving
            fails (a warning is logged instead of raising).

    Note:
        If the specified `output_dir` does not exist, it will be created.
    """
    # Compute the target path *before* the try block: the except handler below
    # interpolates ``output_config_fname`` into its warning, so it must always
    # be bound. (Previously it was assigned inside the try, and a failure in
    # ``Path(...)``/``mkdir`` raised NameError in the handler, masking the
    # original error.)
    output_config_fname = Path(config.output_dir) / "torchtune_config.yaml"
    try:
        output_config_fname.parent.mkdir(parents=True, exist_ok=True)
        log.info(f"Writing config to {output_config_fname}")
        OmegaConf.save(config, output_config_fname)
        return output_config_fname
    except Exception as e:
        log.warning(f"Error saving config to {output_config_fname}.\nError: \n{e}.")
3850

3951

4052
class MetricLoggerInterface(Protocol):
@@ -56,7 +68,7 @@ def log(
5668
pass
5769

5870
def log_config(self, config: DictConfig) -> None:
59-
"""Logs the config
71+
"""Logs the config as file
6072
6173
Args:
6274
config (DictConfig): config to log
@@ -114,7 +126,7 @@ def log(self, name: str, data: Scalar, step: int) -> None:
114126
self._file.flush()
115127

116128
def log_config(self, config: DictConfig) -> None:
    """Write the resolved config to disk as a YAML file (see ``save_config``)."""
    save_config(config)
118130

119131
def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
120132
self._file.write(f"Step {step} | ")
@@ -136,6 +148,9 @@ class StdoutLogger(MetricLoggerInterface):
136148
def log(self, name: str, data: Scalar, step: int) -> None:
137149
print(f"Step {step} | {name}:{data}")
138150

151+
def log_config(self, config: DictConfig) -> None:
    """Persist the resolved config as a YAML file via ``save_config``."""
    save_config(config)
153+
139154
def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
140155
print(f"Step {step} | ", end="")
141156
for name, data in payload.items():
@@ -200,6 +215,10 @@ def __init__(
200215
# Use dir if specified, otherwise use log_dir.
201216
self.log_dir = kwargs.pop("dir", log_dir)
202217

218+
# create log_dir if missing
219+
if not os.path.exists(self.log_dir):
220+
os.makedirs(self.log_dir)
221+
203222
_, self.rank = get_world_size_and_rank()
204223

205224
if self._wandb.run is None and self.rank == 0:
@@ -236,23 +255,17 @@ def log_config(self, config: DictConfig) -> None:
236255
self._wandb.config.update(
237256
resolved, allow_val_change=self.config_allow_val_change
238257
)
239-
try:
240-
output_config_fname = Path(
241-
os.path.join(
242-
config.output_dir,
243-
"torchtune_config.yaml",
244-
)
245-
)
246-
OmegaConf.save(config, output_config_fname)
247258

248-
log.info(f"Logging {output_config_fname} to W&B under Files")
259+
# Also try to save the config as a file
260+
output_config_fname = save_config(config)
261+
try:
262+
log.info(f"Uploading {output_config_fname} to W&B under Files")
249263
self._wandb.save(
250264
output_config_fname, base_path=output_config_fname.parent
251265
)
252-
253266
except Exception as e:
254267
log.warning(
255-
f"Error saving {output_config_fname} to W&B.\nError: \n{e}."
268+
f"Error uploading {output_config_fname} to W&B.\nError: \n{e}."
256269
"Don't worry the config will be logged the W&B workspace"
257270
)
258271

@@ -322,6 +335,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
322335
if self._writer:
323336
self._writer.add_scalar(name, data, global_step=step, new_style=True)
324337

338+
def log_config(self, config: DictConfig) -> None:
    """Save the resolved config to a YAML file via ``save_config``."""
    save_config(config)
340+
325341
def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
326342
for name, data in payload.items():
327343
self.log(name, data, step)
@@ -404,13 +420,15 @@ def __init__(
404420
"Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
405421
) from e
406422

423+
# Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
424+
del kwargs["log_dir"]
425+
407426
_, self.rank = get_world_size_and_rank()
408427

409428
# Declare it early so further methods don't crash in case of
410429
# Experiment Creation failure due to mis-named configuration for
411430
# example
412431
self.experiment = None
413-
414432
if self.rank == 0:
415433
self.experiment = comet_ml.start(
416434
api_key=api_key,
@@ -438,24 +456,14 @@ def log_config(self, config: DictConfig) -> None:
438456
self.experiment.log_parameters(resolved)
439457

440458
# Also try to save the config as a file
459+
output_config_fname = save_config(config)
441460
try:
442-
self._log_config_as_file(config)
461+
log.info(f"Uploading {output_config_fname} to Comet as an asset.")
462+
self.experiment.log_asset(
463+
output_config_fname, file_name=output_config_fname.name
464+
)
443465
except Exception as e:
444-
log.warning(f"Error saving Config to disk.\nError: \n{e}.")
445-
return
446-
447-
def _log_config_as_file(self, config: DictConfig):
448-
output_config_fname = Path(
449-
os.path.join(
450-
config.checkpointer.checkpoint_dir,
451-
"torchtune_config.yaml",
452-
)
453-
)
454-
OmegaConf.save(config, output_config_fname)
455-
456-
self.experiment.log_asset(
457-
output_config_fname, file_name="torchtune_config.yaml"
458-
)
466+
log.warning(f"Failed to upload config to Comet assets. Error: {e}")
459467

460468
def close(self) -> None:
461469
if self.experiment is not None:

0 commit comments

Comments
 (0)