Skip to content

Commit 37d5d32

Browse files
committed
feat(logger): add experiment_tags parameter to MLFlowLogger
1 parent d25014d commit 37d5d32

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def any_lightning_module_function_or_hook(self):
8383
tracking_uri: Address of local or remote tracking server.
8484
If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
8585
back to `file:<save_dir>`.
86-
tags: A dictionary tags for the experiment.
86+
tags: A dictionary of tags to be set on the run.
8787
save_dir: A path to a local directory where the MLflow runs get saved.
8888
Defaults to `./mlruns` if `tracking_uri` is not provided.
8989
Has no effect if `tracking_uri` is provided.
@@ -96,6 +96,8 @@ def any_lightning_module_function_or_hook(self):
9696
which also logs every checkpoint during training.
9797
* if ``log_model == False`` (default), no checkpoint is logged.
9898
99+
experiment_tags: A dictionary of tags to set on the experiment. Has no effect if the experiment already
100+
exists.
99101
prefix: A string to put at the beginning of metric keys.
100102
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
101103
default.
@@ -120,6 +122,7 @@ def __init__(
120122
prefix: str = "",
121123
artifact_location: Optional[str] = None,
122124
run_id: Optional[str] = None,
125+
experiment_tags: Optional[Dict[str, Any]] = None,
123126
):
124127
if not _MLFLOW_AVAILABLE:
125128
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
@@ -133,6 +136,7 @@ def __init__(
133136
self._run_name = run_name
134137
self._run_id = run_id
135138
self.tags = tags
139+
self._experiment_tags = experiment_tags
136140
self._log_model = log_model
137141
self._logged_model_time: Dict[str, float] = {}
138142
self._checkpoint_callback: Optional[ModelCheckpoint] = None
@@ -178,6 +182,9 @@ def experiment(self) -> "MlflowClient":
178182
self._experiment_id = self._mlflow_client.create_experiment(
179183
name=self._experiment_name, artifact_location=self._artifact_location
180184
)
185+
if self._experiment_tags:
186+
for key, value in self._experiment_tags.items():
187+
self._mlflow_client.set_experiment_tag(self._experiment_id, key, value)
181188

182189
if self._run_id is None:
183190
if self._run_name is not None:

0 commit comments

Comments
 (0)