Skip to content

Commit 4d05d95

Browse files
authored
Implement: multiple save and load for mlflow registry (#416)
1. Implemented save_multiple and load_multiple for mlflow registry 2. Test cases for implementation. --------- Signed-off-by: Leila Wang <[email protected]>
1 parent b18b2d2 commit 4d05d95

File tree

5 files changed

+426
-10
lines changed

5 files changed

+426
-10
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ target/
7878
#mlflow
7979
/.mlruns
8080
*.db
81+
mlruns/
82+
mlartifacts/
8183

8284
# Jupyter Notebook
8385
.ipynb_checkpoints
@@ -169,4 +171,7 @@ cython_debug/
169171
# Mac related
170172
*.DS_Store
171173

174+
# vscode
175+
.vscode/
176+
172177
.python-version

numalogic/registry/mlflow_registry.py

+121-5
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
from enum import Enum
1616
from typing import Optional, Any
1717

18-
import mlflow.pyfunc
19-
import mlflow.pytorch
20-
import mlflow.sklearn
18+
import mlflow
2119
from mlflow.entities.model_registry import ModelVersion
22-
from mlflow.exceptions import RestException
20+
from mlflow.exceptions import RestException, MlflowException
2321
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
2422
from mlflow.tracking import MlflowClient
2523

@@ -187,6 +185,43 @@ def load(
187185
self._save_in_cache(model_key, artifact_data)
188186
return artifact_data
189187

188+
def load_multiple(
    self,
    skeys: KEYS,
    dkeys: KEYS,
) -> Optional[ArtifactData]:
    """
    Load a composite (pyfunc) artifact and return its bundled artifacts.

    Args:
    ----
        skeys (KEYS): Static key fields as a list or tuple of strings.
        dkeys (KEYS): Dynamic key fields as a list or tuple of strings.

    Returns
    -------
        Optional[ArtifactData]: ArtifactData whose ``artifact`` field is the
        dictionary of artifacts stored inside the composite model, or None
        when nothing could be loaded or unwrapped.

    Raises
    ------
        TypeError: If the registry entry is not a valid pyfunc Python model.
    """
    artifact_data = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
    if artifact_data is None:
        return None

    try:
        # The pyfunc wrapper holds our CompositeModel; unwrap to reach it.
        composite = artifact_data.artifact.unwrap_python_model()
    except MlflowException as err:
        raise TypeError("The loaded model is not a valid pyfunc Python model.") from err
    except AttributeError:
        _LOGGER.exception("The loaded model does not have an unwrap_python_model method")
        return None
    except Exception:
        # Best-effort: never let an unexpected unwrap failure propagate.
        _LOGGER.exception("Unexpected error occurred while unwrapping python model.")
        return None

    return ArtifactData(
        artifact=composite.dict_artifacts,
        metadata=artifact_data.metadata,
        extras=artifact_data.extras,
    )
224+
190225
@staticmethod
191226
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
192227
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
@@ -225,7 +260,10 @@ def save(
225260
handler = self.handler_from_type(artifact_type)
226261
try:
227262
mlflow.start_run(run_id=run_id)
228-
handler.log_model(artifact, "model", registered_model_name=model_key)
263+
if artifact_type == "pyfunc":
264+
handler.log_model("model", python_model=artifact, registered_model_name=model_key)
265+
else:
266+
handler.log_model(artifact, "model", registered_model_name=model_key)
229267
if metadata:
230268
mlflow.log_params(metadata)
231269
model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
@@ -238,6 +276,42 @@ def save(
238276
finally:
239277
mlflow.end_run()
240278

279+
def save_multiple(
    self,
    skeys: KEYS,
    dkeys: KEYS,
    dict_artifacts: dict[str, artifact_t],
    **metadata: META_VT,
) -> Optional[ModelVersion]:
    """
    Save several artifacts under one registry entry as a composite pyfunc model.

    The artifacts are wrapped in a ``CompositeModel`` and logged in a single
    save, so all of them share one model version.

    Args:
    ----
        skeys (KEYS): Static key fields as a list or tuple of strings.
        dkeys (KEYS): Dynamic key fields as a list or tuple of strings.
        dict_artifacts (dict[str, artifact_t]): Mapping of artifact name to artifact.
        **metadata (META_VT): Additional metadata stored alongside the artifacts.

    Returns
    -------
        Optional[ModelVersion]: The MLflow ModelVersion produced by the save.
    """
    if len(dict_artifacts) == 1:
        # A composite adds overhead for a single artifact; hint the caller.
        _LOGGER.warning(
            "Only one artifact present in dict_artifacts. Saving directly is recommended."
        )
    composite = CompositeModel(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
    return self.save(
        skeys=skeys,
        dkeys=dkeys,
        artifact=composite,
        artifact_type="pyfunc",
        **metadata,
    )
314+
241315
@staticmethod
242316
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
243317
"""Returns whether the given artifact is stale or not, i.e. if
@@ -338,3 +412,45 @@ def __load_artifacts(
338412
version_info.version,
339413
)
340414
return model, metadata
415+
416+
417+
class CompositeModel(mlflow.pyfunc.PythonModel):
    """A pyfunc wrapper bundling several artifacts into one registry entry.

    Extends ``mlflow.pyfunc.PythonModel`` purely as a storage container so
    that multiple related artifacts can be logged and loaded together under
    a single MLflow model version.

    Args:
    ----
        skeys (KEYS): Static key fields identifying the artifacts.
        dict_artifacts (dict[str, artifact_t]): Mapping of artifact name to artifact.
        **metadata (META_VT): Additional metadata kept with the artifacts.

    Attributes
    ----------
        skeys (KEYS): Static key fields identifying the artifacts.
        dict_artifacts (dict[str, artifact_t]): Mapping of artifact name to artifact.
        metadata (META_VT): Additional metadata kept with the artifacts.

    Methods
    -------
        predict: Intentionally unsupported; this class only stores artifacts.
    """

    # Fixed attribute set: keeps instances lean and guards against typos.
    __slots__ = ("skeys", "dict_artifacts", "metadata")

    def __init__(self, skeys: KEYS, dict_artifacts: dict[str, artifact_t], **metadata: META_VT):
        self.skeys = skeys
        self.dict_artifacts = dict_artifacts
        self.metadata = metadata

    def predict(self, context, model_input, params: Optional[dict[str, Any]] = None):
        """
        Unsupported operation.

        CompositeModel exists only to store and load multiple artifacts;
        it is never used for inference, so predict is deliberately disabled.
        """
        raise NotImplementedError("The predict method is not implemented for CompositeModel.")

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "numalogic"
3-
version = "0.13.2"
3+
version = "0.13.3"
44
description = "Collection of operational Machine Learning models and tools."
55
authors = ["Numalogic Developers"]
66
packages = [{ include = "numalogic" }]

tests/registry/_mlflow_utils.py

+124
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from mlflow.store.entities import PagedList
1010
from sklearn.preprocessing import StandardScaler
1111
from torch import tensor
12+
from mlflow.models import Model
1213

14+
from numalogic.models.autoencoder.variants.vanilla import VanillaAE
1315
from numalogic.models.threshold import StdDevThreshold
16+
from numalogic.registry.mlflow_registry import CompositeModel
1417

1518

1619
def create_model():
@@ -135,6 +138,103 @@ def mock_log_model_sklearn(*_, **__):
135138
)
136139

137140

141+
def mock_log_model_pyfunc(*_, **__):
    """Return a canned ModelInfo mimicking ``mlflow.pyfunc.log_model`` output."""
    run_id = "a7c0b376530b40d7b23e6ce2081c899c"
    pyfunc_flavors = {
        "pyfunc": {"model_data": "data", "pyfunc_version": "1.11.0", "code": None},
        "python_function": {
            "pickle_module_name": "mlflow.pyfunc.pickle_module",
            "loader_module": "mlflow.pyfunc",
            "python_version": "3.8.5",
            "data": "data",
            "env": "conda.yaml",
        },
    }
    return ModelInfo(
        artifact_path="model",
        flavors=pyfunc_flavors,
        model_uri=f"runs:/{run_id}/model",
        model_uuid=run_id,
        run_id=run_id,
        saved_input_example_info=None,
        signature_dict=None,
        utc_time_created="2022-05-23 22:35:59.557372",
        mlflow_version="2.0.1",
        signature=None,
    )
163+
164+
165+
def mock_load_model_pyfunc(*_, **__):
    """Return a canned PyFuncModel wrapping a CompositeModel, as pyfunc load would."""
    meta = Model(
        artifact_path="model",
        flavors={
            "python_function": {
                "cloudpickle_version": "3.0.0",
                "code": None,
                "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
                "loader_module": "mlflow.pyfunc.model",
                "python_model": "python_model.pkl",
                "python_version": "3.10.14",
                "streamable": False,
            }
        },
        model_size_bytes=8912,
        model_uuid="ae27ecc166c94c01a4f4dccaf84ca5dc",
        run_id="7e85a3fa46d44e668c840f3dddc909c3",
        utc_time_created="2024-09-18 17:12:41.501209",
        mlflow_version="2.16.0",
    )
    composite = CompositeModel(
        skeys=["error"],
        dict_artifacts={
            "inference": VanillaAE(10),
            # NOTE(review): "precrocessing" looks like a typo for "preprocessing",
            # but the key is runtime data other tests may assert on — kept as-is.
            "precrocessing": StandardScaler(),
            "threshold": StdDevThreshold(),
        },
        learning_rate=0.01,
    )
    return mlflow.pyfunc.PyFuncModel(
        model_meta=meta,
        model_impl=TestObject(python_model=composite),
    )
205+
206+
207+
def mock_load_model_pyfunc_type_error(*_, **__):
    """Return a PyFuncModel whose impl is NOT a PythonModel wrapper.

    Used to exercise the error path of ``load_multiple``: unwrapping this
    model cannot yield a CompositeModel.
    """
    meta = Model(
        artifact_path="model",
        flavors={
            "python_function": {
                "cloudpickle_version": "3.0.0",
                "code": None,
                "env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
                "loader_module": "mlflow.pytorch.model",
                "python_model": "python_model.pkl",
                "python_version": "3.10.14",
                "streamable": False,
            }
        },
        model_size_bytes=8912,
        model_uuid="ae27ecc166c94c01a4f4dccaf84ca5dc",
        run_id="7e85a3fa46d44e668c840f3dddc909c3",
        utc_time_created="2024-09-18 17:12:41.501209",
        mlflow_version="2.16.0",
    )
    # A torch wrapper instead of a PythonModel wrapper triggers the failure.
    return mlflow.pyfunc.PyFuncModel(
        model_meta=meta,
        model_impl=mlflow.pytorch._PyTorchWrapper(VanillaAE(10), device="cpu"),
    )
236+
237+
138238
def mock_transition_stage(*_, **__):
139239
return ModelVersion(
140240
creation_timestamp=1653402941169,
@@ -303,6 +403,25 @@ def return_sklearn_rundata():
303403
)
304404

305405

406+
def return_pyfunc_rundata():
    """Return a canned mlflow Run for pyfunc save/load test scenarios."""
    rid = "a7c0b376530b40d7b23e6ce2081c899c"
    info = RunInfo(
        artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model",
        end_time=None,
        experiment_id="0",
        lifecycle_stage="active",
        run_id=rid,
        run_uuid=rid,
        start_time=1658788772612,
        status="RUNNING",
        user_id="lol",
    )
    data = RunData(metrics={}, tags={}, params=[mlflow.entities.Param("learning_rate", "0.01")])
    return Run(run_info=info, run_data=data)
423+
424+
306425
def return_pytorch_rundata_dict():
307426
return Run(
308427
run_info=RunInfo(
@@ -318,3 +437,8 @@ def return_pytorch_rundata_dict():
318437
),
319438
run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]),
320439
)
440+
441+
442+
class TestObject(mlflow.pyfunc.PythonModel):
    """Minimal PythonModel stub that exposes a wrapped model for tests.

    The ``python_model`` attribute name is load-bearing — presumably read by
    ``PyFuncModel.unwrap_python_model()``; verify before renaming.
    """

    def __init__(self, python_model):
        self.python_model = python_model

0 commit comments

Comments
 (0)