@@ -83,7 +83,7 @@ def any_lightning_module_function_or_hook(self):
8383 tracking_uri: Address of local or remote tracking server.
8484 If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
8585 back to `file:<save_dir>`.
86- tags: A dictionary tags for the experiment .
86+ tags: A dictionary of tags to be set on the run .
8787 save_dir: A path to a local directory where the MLflow runs get saved.
8888 Defaults to `./mlruns` if `tracking_uri` is not provided.
8989 Has no effect if `tracking_uri` is provided.
@@ -96,6 +96,8 @@ def any_lightning_module_function_or_hook(self):
9696 which also logs every checkpoint during training.
9797 * if ``log_model == False`` (default), no checkpoint is logged.
9898
99+ experiment_tags: A dictionary of tags to set on the experiment. Has no effect if the experiment already
100+ exists.
99101 prefix: A string to put at the beginning of metric keys.
100102 artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
101103 default.
@@ -120,6 +122,7 @@ def __init__(
120122 prefix : str = "" ,
121123 artifact_location : Optional [str ] = None ,
122124 run_id : Optional [str ] = None ,
125+ experiment_tags : Optional [Dict [str , Any ]] = None ,
123126 ):
124127 if not _MLFLOW_AVAILABLE :
125128 raise ModuleNotFoundError (str (_MLFLOW_AVAILABLE ))
@@ -133,6 +136,7 @@ def __init__(
133136 self ._run_name = run_name
134137 self ._run_id = run_id
135138 self .tags = tags
139+ self ._experiment_tags = experiment_tags
136140 self ._log_model = log_model
137141 self ._logged_model_time : Dict [str , float ] = {}
138142 self ._checkpoint_callback : Optional [ModelCheckpoint ] = None
@@ -178,6 +182,9 @@ def experiment(self) -> "MlflowClient":
178182 self ._experiment_id = self ._mlflow_client .create_experiment (
179183 name = self ._experiment_name , artifact_location = self ._artifact_location
180184 )
185+ if self ._experiment_tags :
186+ for key , value in self ._experiment_tags .items ():
187+ self ._mlflow_client .set_experiment_tag (self ._experiment_id , key , value )
181188
182189 if self ._run_id is None :
183190 if self ._run_name is not None :
0 commit comments