15
15
from enum import Enum
16
16
from typing import Optional , Any
17
17
18
- import mlflow .pyfunc
19
- import mlflow .pytorch
20
- import mlflow .sklearn
18
+ import mlflow
21
19
from mlflow .entities .model_registry import ModelVersion
22
- from mlflow .exceptions import RestException
20
+ from mlflow .exceptions import RestException , MlflowException
23
21
from mlflow .protos .databricks_pb2 import ErrorCode , RESOURCE_DOES_NOT_EXIST
24
22
from mlflow .tracking import MlflowClient
25
23
@@ -187,6 +185,43 @@ def load(
187
185
self ._save_in_cache (model_key , artifact_data )
188
186
return artifact_data
189
187
188
+ def load_multiple (
189
+ self ,
190
+ skeys : KEYS ,
191
+ dkeys : KEYS ,
192
+ ) -> Optional [ArtifactData ]:
193
+ """
194
+ Load multiple artifacts from the registry for pyfunc models.
195
+ Args:
196
+ skeys (KEYS): The source keys of the artifacts to load.
197
+ dkeys: dynamic key fields as list/tuple of strings.
198
+
199
+ Returns
200
+ -------
201
+ Optional[ArtifactData]: The loaded ArtifactData object if available otherwise None.
202
+ ArtifactData should contain a dictionary of artifacts.
203
+ """
204
+ loaded_model = self .load (skeys = skeys , dkeys = dkeys , artifact_type = "pyfunc" )
205
+ if loaded_model is None :
206
+ return None
207
+
208
+ try :
209
+ unwrapped_composite_model = loaded_model .artifact .unwrap_python_model ()
210
+ except MlflowException as e :
211
+ raise TypeError ("The loaded model is not a valid pyfunc Python model." ) from e
212
+ except AttributeError :
213
+ _LOGGER .exception ("The loaded model does not have an unwrap_python_model method" )
214
+ return None
215
+ except Exception :
216
+ _LOGGER .exception ("Unexpected error occurred while unwrapping python model." )
217
+ return None
218
+
219
+ return ArtifactData (
220
+ artifact = unwrapped_composite_model .dict_artifacts ,
221
+ metadata = loaded_model .metadata ,
222
+ extras = loaded_model .extras ,
223
+ )
224
+
190
225
@staticmethod
191
226
def __log_mlflow_err (mlflow_err : RestException , model_key : str ) -> None :
192
227
if ErrorCode .Value (mlflow_err .error_code ) == RESOURCE_DOES_NOT_EXIST :
@@ -225,7 +260,10 @@ def save(
225
260
handler = self .handler_from_type (artifact_type )
226
261
try :
227
262
mlflow .start_run (run_id = run_id )
228
- handler .log_model (artifact , "model" , registered_model_name = model_key )
263
+ if artifact_type == "pyfunc" :
264
+ handler .log_model ("model" , python_model = artifact , registered_model_name = model_key )
265
+ else :
266
+ handler .log_model (artifact , "model" , registered_model_name = model_key )
229
267
if metadata :
230
268
mlflow .log_params (metadata )
231
269
model_version = self .transition_stage (skeys = skeys , dkeys = dkeys )
@@ -238,6 +276,42 @@ def save(
238
276
finally :
239
277
mlflow .end_run ()
240
278
279
+ def save_multiple (
280
+ self ,
281
+ skeys : KEYS ,
282
+ dkeys : KEYS ,
283
+ dict_artifacts : dict [str , artifact_t ],
284
+ ** metadata : META_VT ,
285
+ ) -> Optional [ModelVersion ]:
286
+ """
287
+ Saves multiple artifacts into mlflow registry. The last save stores all the
288
+ artifact versions in the metadata.
289
+
290
+ Args:
291
+ ----
292
+ skeys (KEYS): Static key fields as a list or tuple of strings.
293
+ dkeys (KEYS): Dynamic key fields as a list or tuple of strings.
294
+ dict_artifacts (dict[str, artifact_t]): Dictionary of artifacts to save.
295
+ **metadata (META_VT): Additional metadata to be saved with the artifacts.
296
+
297
+ Returns
298
+ -------
299
+ Optional[ModelVersion]: An instance of the MLflow ModelVersion.
300
+
301
+ """
302
+ if len (dict_artifacts ) == 1 :
303
+ _LOGGER .warning (
304
+ "Only one artifact present in dict_artifacts. Saving directly is recommended."
305
+ )
306
+ multiple_artifacts = CompositeModel (skeys = skeys , dict_artifacts = dict_artifacts , ** metadata )
307
+ return self .save (
308
+ skeys = skeys ,
309
+ dkeys = dkeys ,
310
+ artifact = multiple_artifacts ,
311
+ artifact_type = "pyfunc" ,
312
+ ** metadata ,
313
+ )
314
+
241
315
@staticmethod
242
316
def is_artifact_stale (artifact_data : ArtifactData , freq_hr : int ) -> bool :
243
317
"""Returns whether the given artifact is stale or not, i.e. if
@@ -338,3 +412,45 @@ def __load_artifacts(
338
412
version_info .version ,
339
413
)
340
414
return model , metadata
415
+
416
+
417
+ class CompositeModel (mlflow .pyfunc .PythonModel ):
418
+ """A composite model that represents multiple artifacts.
419
+
420
+ This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load
421
+ multiple artifacts in the MLflow registry. It provides a convenient way to manage and
422
+ organize multiple artifacts associated with a single model.
423
+
424
+ Args:
425
+ skeys (KEYS): The static keys of the artifacts.
426
+ dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
427
+ `KeyedArtifact` objects.
428
+ **metadata (META_VT): Additional metadata associated with the artifacts.
429
+
430
+ Methods
431
+ -------
432
+ predict: Not implemented for our use case.
433
+
434
+ Attributes
435
+ ----------
436
+ skeys (KEYS): The static keys of the artifacts.
437
+ dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
438
+ `KeyedArtifact` objects.
439
+ metadata (META_VT): Additional metadata associated with the artifacts.
440
+ """
441
+
442
+ __slots__ = ("skeys" , "dict_artifacts" , "metadata" )
443
+
444
+ def __init__ (self , skeys : KEYS , dict_artifacts : dict [str , artifact_t ], ** metadata : META_VT ):
445
+ self .skeys = skeys
446
+ self .dict_artifacts = dict_artifacts
447
+ self .metadata = metadata
448
+
449
+ def predict (self , context , model_input , params : Optional [dict [str , Any ]] = None ):
450
+ """
451
+ Predict method is not implemented for our use case.
452
+
453
+ The CompositeModel class is designed to store and load multiple artifacts,
454
+ and the predict method is not required for this functionality.
455
+ """
456
+ raise NotImplementedError ("The predict method is not implemented for CompositeModel." )
0 commit comments