@@ -33,6 +33,17 @@ class MLFLowSaver(DataSaver):
33
33
@classmethod
34
34
def applicable_types (cls ) -> Collection [Type ]:
35
35
return [object ]
36
+ # we need an implementation for this. so return empty list.
37
+ # this is a simpler way to specify what types this adapter can handle.
38
+
39
+ # @classmethod
40
+ # def applies_to(cls, type_: Type[Type]) -> bool:
41
+ # # This allows to override how we determine if a type is applicable.
42
+ # str_type = type_.__module__
43
+ # for model_type in ["sklearn", "pytorch", "tensorflow", "xgboost", "keras"]:
44
+ # if str_type.startswith(model_type):
45
+ # return True
46
+ # return False
36
47
37
48
@classmethod
38
49
def name (cls ) -> str :
@@ -44,16 +55,16 @@ def save_data(self, model: object) -> Dict[str, Any]:
44
55
with mlflow .start_run (run_name = self .run_name ) as run :
45
56
# Log the parameters used for the model fit
46
57
# mlflow.log_params(data["params"])
58
+
47
59
# Log the error metrics that were calculated
48
60
# mlflow.log_metrics(data["metrics"])
49
61
50
62
# Log an instance of the trained model for later use
51
63
ml_logger = getattr (mlflow , self .model_type )
52
64
model_info = ml_logger .log_model (
53
65
model ,
54
- # data["trained_model"],
55
66
self .artifact_path ,
56
- # input_example=data["input_example"],
67
+ # input_example=data["input_example"], # or signature
57
68
)
58
69
return {
59
70
"model_info" : model_info .__dict__ , # return some metadata
0 commit comments