Logging resolved config (#2274)
Ankur-singh authored Jan 21, 2025
1 parent 779569e commit 75965d4
Showing 1 changed file with 52 additions and 29 deletions.
81 changes: 52 additions & 29 deletions torchtune/training/metric_logging.py
@@ -23,6 +23,31 @@
log = get_logger("DEBUG")


def save_config(config: DictConfig) -> Path:
    """
    Save the OmegaConf configuration to a YAML file at `{config.output_dir}/torchtune_config.yaml`.

    Args:
        config (DictConfig): The OmegaConf config object to be saved. It must contain an `output_dir` attribute
            specifying where the configuration file should be saved.

    Returns:
        Path: The path to the saved configuration file.

    Note:
        If the specified `output_dir` does not exist, it will be created.
    """
    try:
        output_dir = Path(config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        output_config_fname = output_dir / "torchtune_config.yaml"
        OmegaConf.save(config, output_config_fname)
        return output_config_fname
    except Exception as e:
        log.warning(f"Error saving config.\nError: \n{e}.")


class MetricLoggerInterface(Protocol):
"""Abstract metric logger."""

@@ -42,7 +67,7 @@ def log(
pass

def log_config(self, config: DictConfig) -> None:
"""Logs the config
"""Logs the config as file
Args:
config (DictConfig): config to log
@@ -99,6 +124,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
self._file.write(f"Step {step} | {name}:{data}\n")
self._file.flush()

def log_config(self, config: DictConfig) -> None:
_ = save_config(config)

def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
self._file.write(f"Step {step} | ")
for name, data in payload.items():
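The same one-line delegation to save_config is added to StdoutLogger and TensorBoardLogger further down, so every backend now writes the config to `{output_dir}/torchtune_config.yaml`. A rough usage sketch with made-up paths and fields:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"output_dir": "/tmp/run0", "epochs": 3})
    logger = DiskLogger(log_dir="/tmp/run0/logs")
    logger.log_config(cfg)            # writes /tmp/run0/torchtune_config.yaml
    logger.log("loss", 1.23, step=0)  # metrics still go to the per-run log file
    logger.close()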
@@ -119,6 +147,9 @@ class StdoutLogger(MetricLoggerInterface):
def log(self, name: str, data: Scalar, step: int) -> None:
print(f"Step {step} | {name}:{data}")

def log_config(self, config: DictConfig) -> None:
_ = save_config(config)

def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
print(f"Step {step} | ", end="")
for name, data in payload.items():
@@ -183,6 +214,10 @@ def __init__(
# Use dir if specified, otherwise use log_dir.
self.log_dir = kwargs.pop("dir", log_dir)

# create log_dir if missing
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)

_, self.rank = get_world_size_and_rank()

if self._wandb.run is None and self.rank == 0:
@@ -219,23 +254,16 @@ def log_config(self, config: DictConfig) -> None:
self._wandb.config.update(
resolved, allow_val_change=self.config_allow_val_change
)
-            try:
-                output_config_fname = Path(
-                    os.path.join(
-                        config.output_dir,
-                        "torchtune_config.yaml",
-                    )
-                )
-                OmegaConf.save(config, output_config_fname)
-
-                log.info(f"Logging {output_config_fname} to W&B under Files")
+            # Also try to save the config as a file
+            output_config_fname = save_config(config)
+            try:
                 self._wandb.save(
                     output_config_fname, base_path=output_config_fname.parent
                 )
-
             except Exception as e:
                 log.warning(
-                    f"Error saving {output_config_fname} to W&B.\nError: \n{e}."
+                    f"Error uploading {output_config_fname} to W&B.\nError: \n{e}."
                     "Don't worry the config will be logged the W&B workspace"
                 )
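Here the W&B-specific file write is replaced by the shared save_config helper, and W&B then simply uploads the resulting YAML. The `resolved` dict pushed into wandb.config in the unchanged lines just above comes from OmegaConf's resolve=True conversion, so `${...}` interpolations are materialized before logging. A small self-contained illustration with invented keys:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {"output_dir": "/tmp/run0", "checkpoint_dir": "${output_dir}/ckpts"}
    )
    print(OmegaConf.to_container(cfg, resolve=True))
    # {'output_dir': '/tmp/run0', 'checkpoint_dir': '/tmp/run0/ckpts'}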

@@ -305,6 +333,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
if self._writer:
self._writer.add_scalar(name, data, global_step=step, new_style=True)

def log_config(self, config: DictConfig) -> None:
_ = save_config(config)

def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
for name, data in payload.items():
self.log(name, data, step)
@@ -387,13 +418,16 @@ def __init__(
"Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
) from e

# Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
if "log_dir" in kwargs:
del kwargs["log_dir"]

_, self.rank = get_world_size_and_rank()

# Declare it early so further methods don't crash in case of
# Experiment Creation failure due to mis-named configuration for
# example
self.experiment = None

if self.rank == 0:
self.experiment = comet_ml.start(
api_key=api_key,
@@ -421,24 +455,13 @@ def log_config(self, config: DictConfig) -> None:
self.experiment.log_parameters(resolved)

        # Also try to save the config as a file
+        output_config_fname = save_config(config)
        try:
-            self._log_config_as_file(config)
+            self.experiment.log_asset(
+                output_config_fname, file_name=output_config_fname.name
+            )
        except Exception as e:
-            log.warning(f"Error saving Config to disk.\nError: \n{e}.")
-            return
-
-    def _log_config_as_file(self, config: DictConfig):
-        output_config_fname = Path(
-            os.path.join(
-                config.checkpointer.checkpoint_dir,
-                "torchtune_config.yaml",
-            )
-        )
-        OmegaConf.save(config, output_config_fname)
-
-        self.experiment.log_asset(
-            output_config_fname, file_name="torchtune_config.yaml"
-        )
+            log.warning(f"Failed to upload config to Comet assets. Error: {e}")

def close(self) -> None:
if self.experiment is not None:
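Putting the Comet path together: log_config now sends the resolved parameters to the experiment and attaches the YAML written by save_config as an asset, in place of the removed _log_config_as_file helper. A rough end-to-end sketch, assuming a Comet API key is available through the COMET_API_KEY environment variable and using invented config fields:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"output_dir": "/tmp/run0", "batch_size": 8})
    logger = CometLogger()   # the rank-zero process starts the experiment
    logger.log_config(cfg)   # log_parameters(...) + log_asset(torchtune_config.yaml)
    logger.log("loss", 0.7, step=1)
    logger.close()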
