You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -1 +1 @@-gzip compressed data, was "flytekitplugins-onnxscikitlearn-1.9.0a0.tar", last modified: Thu Jul 20 18:58:23 2023, max compression+gzip compressed data, was "flytekitplugins-onnxscikitlearn-1.9.1.tar", last modified: Mon Aug 28 16:43:11 2023, max compression
@@ -1,30 +1,29 @@
from __future__ import annotations
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import skl2onnx.common.data_types
-from dataclasses_json import dataclass_json+from dataclasses_json import DataClassJsonMixin
from skl2onnx import convert_sklearn
from sklearn.base import BaseEstimator
from typing_extensions import Annotated, get_args, get_origin
from flytekit import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.file import ONNXFile
-@dataclass_json
@dataclass
-class ScikitLearn2ONNXConfig:+class ScikitLearn2ONNXConfig(DataClassJsonMixin):
"""
ScikitLearn2ONNXConfig is the config used during the scikitlearn to ONNX conversion.
Args:
initial_types: The types of the inputs to the model.
name: The name of the graph in the produced ONNX model.
doc_string: A string attached onto the produced ONNX model.
@@ -67,17 +66,16 @@
validate_final_types = [
True for item in self.final_types if item in inspect.getmembers(skl2onnx.common.data_types)
]
if not all(validate_final_types):
raise ValueError("All types in final_types must be in skl2onnx.common.data_types")
-@dataclass_json
@dataclass
-class ScikitLearn2ONNX:+class ScikitLearn2ONNX(DataClassJsonMixin):
model: BaseEstimator = field(default=None)
def extract_config(t: Type[ScikitLearn2ONNX]) -> Tuple[Type[ScikitLearn2ONNX], ScikitLearn2ONNXConfig]:
config = None
if get_origin(t) is Annotated: