Skip to content

Commit 6029961

Browse files
committed
Adds some more notes
1 parent 87953fe commit 6029961

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

hamilton/plugins/mlflow_extensions.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ class MLFLowSaver(DataSaver):
3333
@classmethod
3434
def applicable_types(cls) -> Collection[Type]:
3535
return [object]
36+
# we need an implementation for this. so return empty list.
37+
# this is a simpler way to specify what types this adapter can handle.
38+
39+
# @classmethod
40+
# def applies_to(cls, type_: Type[Type]) -> bool:
41+
# # This allows to override how we determine if a type is applicable.
42+
# str_type = type_.__module__
43+
# for model_type in ["sklearn", "pytorch", "tensorflow", "xgboost", "keras"]:
44+
# if str_type.startswith(model_type):
45+
# return True
46+
# return False
3647

3748
@classmethod
3849
def name(cls) -> str:
@@ -44,16 +55,16 @@ def save_data(self, model: object) -> Dict[str, Any]:
4455
with mlflow.start_run(run_name=self.run_name) as run:
4556
# Log the parameters used for the model fit
4657
# mlflow.log_params(data["params"])
58+
4759
# Log the error metrics that were calculated
4860
# mlflow.log_metrics(data["metrics"])
4961

5062
# Log an instance of the trained model for later use
5163
ml_logger = getattr(mlflow, self.model_type)
5264
model_info = ml_logger.log_model(
5365
model,
54-
# data["trained_model"],
5566
self.artifact_path,
56-
# input_example=data["input_example"],
67+
# input_example=data["input_example"], # or signature
5768
)
5869
return {
5970
"model_info": model_info.__dict__, # return some metadata

0 commit comments

Comments
 (0)