From 75965d4281b9b76c454630d015221b9933c77bf3 Mon Sep 17 00:00:00 2001
From: Ankur Singh
Date: Tue, 21 Jan 2025 10:44:53 -0800
Subject: [PATCH] Logging resolved config (#2274)

---
 torchtune/training/metric_logging.py | 81 ++++++++++++++++++----------
 1 file changed, 52 insertions(+), 29 deletions(-)

diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py
index 42aa1f9d72..dde1619194 100644
--- a/torchtune/training/metric_logging.py
+++ b/torchtune/training/metric_logging.py
@@ -23,6 +23,31 @@
 log = get_logger("DEBUG")
 
 
+def save_config(config: DictConfig) -> Path:
+    """
+    Save the OmegaConf configuration to a YAML file at `{config.output_dir}/torchtune_config.yaml`.
+
+    Args:
+        config (DictConfig): The OmegaConf config object to be saved. It must contain an `output_dir` attribute
+            specifying where the configuration file should be saved.
+
+    Returns:
+        Path: The path to the saved configuration file.
+
+    Note:
+        If the specified `output_dir` does not exist, it will be created.
+    """
+    try:
+        output_dir = Path(config.output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        output_config_fname = output_dir / "torchtune_config.yaml"
+        OmegaConf.save(config, output_config_fname)
+        return output_config_fname
+    except Exception as e:
+        log.warning(f"Error saving config.\nError: \n{e}.")
+
+
 class MetricLoggerInterface(Protocol):
     """Abstract metric logger."""
 
@@ -42,7 +67,7 @@ def log(
         pass
 
     def log_config(self, config: DictConfig) -> None:
-        """Logs the config
+        """Logs the config as file
 
         Args:
             config (DictConfig): config to log
@@ -99,6 +124,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
         self._file.write(f"Step {step} | {name}:{data}\n")
         self._file.flush()
 
+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
         self._file.write(f"Step {step} | ")
         for name, data in payload.items():
@@ -119,6 +147,9 @@ class StdoutLogger(MetricLoggerInterface):
     def log(self, name: str, data: Scalar, step: int) -> None:
         print(f"Step {step} | {name}:{data}")
 
+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
         print(f"Step {step} | ", end="")
         for name, data in payload.items():
@@ -183,6 +214,10 @@ def __init__(
         # Use dir if specified, otherwise use log_dir.
         self.log_dir = kwargs.pop("dir", log_dir)
 
+        # create log_dir if missing
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+
         _, self.rank = get_world_size_and_rank()
 
         if self._wandb.run is None and self.rank == 0:
@@ -219,23 +254,16 @@ def log_config(self, config: DictConfig) -> None:
             self._wandb.config.update(
                 resolved, allow_val_change=self.config_allow_val_change
             )
-            try:
-                output_config_fname = Path(
-                    os.path.join(
-                        config.output_dir,
-                        "torchtune_config.yaml",
-                    )
-                )
-                OmegaConf.save(config, output_config_fname)
-                log.info(f"Logging {output_config_fname} to W&B under Files")
+            # Also try to save the config as a file
+            output_config_fname = save_config(config)
+            try:
                 self._wandb.save(
                     output_config_fname, base_path=output_config_fname.parent
                 )
-
             except Exception as e:
                 log.warning(
-                    f"Error saving {output_config_fname} to W&B.\nError: \n{e}."
+                    f"Error uploading {output_config_fname} to W&B.\nError: \n{e}."
                     "Don't worry the config will be logged the W&B workspace"
                 )
@@ -305,6 +333,9 @@ def log(self, name: str, data: Scalar, step: int) -> None:
         if self._writer:
             self._writer.add_scalar(name, data, global_step=step, new_style=True)
 
+    def log_config(self, config: DictConfig) -> None:
+        _ = save_config(config)
+
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
         for name, data in payload.items():
             self.log(name, data, step)
@@ -387,13 +418,16 @@ def __init__(
                 "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
             ) from e
 
+        # Remove 'log_dir' from kwargs as it is not a valid argument for comet_ml.ExperimentConfig
+        if "log_dir" in kwargs:
+            del kwargs["log_dir"]
+
         _, self.rank = get_world_size_and_rank()
 
         # Declare it early so further methods don't crash in case of
         # Experiment Creation failure due to mis-named configuration for
         # example
         self.experiment = None
-
         if self.rank == 0:
             self.experiment = comet_ml.start(
                 api_key=api_key,
@@ -421,24 +455,13 @@ def log_config(self, config: DictConfig) -> None:
             self.experiment.log_parameters(resolved)
 
             # Also try to save the config as a file
+            output_config_fname = save_config(config)
             try:
-                self._log_config_as_file(config)
+                self.experiment.log_asset(
+                    output_config_fname, file_name=output_config_fname.name
+                )
             except Exception as e:
-                log.warning(f"Error saving Config to disk.\nError: \n{e}.")
-                return
-
-    def _log_config_as_file(self, config: DictConfig):
-        output_config_fname = Path(
-            os.path.join(
-                config.checkpointer.checkpoint_dir,
-                "torchtune_config.yaml",
-            )
-        )
-        OmegaConf.save(config, output_config_fname)
-
-        self.experiment.log_asset(
-            output_config_fname, file_name="torchtune_config.yaml"
-        )
+                log.warning(f"Failed to upload config to Comet assets. Error: {e}")
 
     def close(self) -> None:
         if self.experiment is not None: