-
Notifications
You must be signed in to change notification settings - Fork 62
Open
Description
SKOPS is unable to load an sklearn pipeline that contains a TargetEncoder object.
The code below reproduces the issue.
Simple code for a classifier using Target Encoder
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from category_encoders import TargetEncoder
from sklearn.metrics import accuracy_score
# Reproduction script: builds a ten-row toy dataset, target-encodes its
# categorical column, then trains and scores a random forest. The fitted
# `encoder` object created here is what gets persisted with skops afterwards.
# Sample data: one string category column, two integer features, 0/1 target
data = {
'category': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'C', 'B', 'A'],
'feature1': [10, 20, 10, 30, 20, 10, 30, 30, 20, 10],
'feature2': [1, 2, 1, 3, 2, 1, 3, 3, 2, 1],
'target': [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]
}
# Create DataFrame
df = pd.DataFrame(data)
# Separate features and target
X = df.drop('target', axis=1)
y = df['target']
# Encode the category column using Target Encoder
# NOTE(review): the encoder is fitted on the full dataset, before the
# train/test split below, so it sees the targets of the rows that end up in
# the test set. That does not matter for reproducing the load error, but it
# is not a pattern to copy into a real model.
encoder = TargetEncoder()
X['category'] = encoder.fit_transform(X['category'], y)
# Split the data into training and testing sets (20% held out, fixed seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize and train the classifier
# NOTE(review): no random_state is passed here, so the fitted forest (and
# possibly the printed accuracy) may differ between runs.
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
# Make predictions
y_pred = clf.predict(X_test)
# Calculate accuracy on the held-out rows
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
Using SKOPS to store the encoder in a file and get the unknown data types
from skops import io as sio
# Persist the fitted TargetEncoder (not the classifier) to a skops file.
sio.dump(encoder, "test.sio")
# Ask skops which types stored in that file it reports as untrusted; the
# result is a list of dotted type names (see the printed output below).
# NOTE(review): this list is later passed verbatim as `trusted=` to
# sio.load — presumably acceptable for a reproducer only; confirm each type
# before trusting it in real use.
unknown_type = sio.get_untrusted_types(file="test.sio")
print(unknown_type)

Output of unknown_type
['builtins.object',
'category_encoders.ordinal.OrdinalEncoder',
'category_encoders.target_encoder.TargetEncoder',
'numpy.dtype',
'pandas._libs.index.Int64Engine',
'pandas._libs.index.ObjectEngine',
'pandas._libs.internals.BlockValuesRefs',
'pandas.core.indexes.base.Index',
'pandas.core.internals.managers.SingleBlockManager',
'pandas.core.series.Series']

Use SKOPS to load the file
sio.load("test.sio",trusted=unknown_type)

Obtained the error
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[14], line 1
----> 1 sio.load("test.sio",trusted=unknown_type)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_persist.py:152, in load(file, trusted)
150 tree = get_tree(schema, load_context, trusted=trusted)
151 audit_tree(tree)
--> 152 instance = tree.construct()
154 return instance
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
--> 420 attrs = self.children["attrs"].construct()
421 if attrs is not None:
422 if hasattr(instance, "__setstate__"):
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
77 key_types = self.children["key_types"].construct()
78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79 content[k_type(key)] = val.construct()
80 return content
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
77 key_types = self.children["key_types"].construct()
78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79 content[k_type(key)] = val.construct()
80 return content
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
--> 420 attrs = self.children["attrs"].construct()
421 if attrs is not None:
422 if hasattr(instance, "__setstate__"):
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
77 key_types = self.children["key_types"].construct()
78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79 content[k_type(key)] = val.construct()
80 return content
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
--> 420 attrs = self.children["attrs"].construct()
421 if attrs is not None:
422 if hasattr(instance, "__setstate__"):
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:179, in TupleNode._construct(self)
175 def _construct(self):
176 # Returns a tuple or a namedtuple instance.
178 cls = gettype(self.module_name, self.class_name)
--> 179 content = tuple(value.construct() for value in self.children["content"])
181 if self.isnamedtuple(cls):
182 return cls(*content)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:179, in <genexpr>(.0)
175 def _construct(self):
176 # Returns a tuple or a namedtuple instance.
178 cls = gettype(self.module_name, self.class_name)
--> 179 content = tuple(value.construct() for value in self.children["content"])
181 if self.isnamedtuple(cls):
182 return cls(*content)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:113, in ListNode._construct(self)
111 def _construct(self):
112 content_type = gettype(self.module_name, self.class_name)
--> 113 return content_type([item.construct() for item in self.children["content"]])
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:113, in <listcomp>(.0)
111 def _construct(self):
112 content_type = gettype(self.module_name, self.class_name)
--> 113 return content_type([item.construct() for item in self.children["content"]])
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:414, in ObjectNode._construct(self)
408 cls = gettype(self.module_name, self.class_name)
410 # Instead of simply constructing the instance, we use __new__, which
411 # bypasses the __init__, and then we set the attributes. This solves the
412 # issue of required init arguments. Note that the instance created here
413 # might not be valid until all its attributes have been set below.
--> 414 instance = cls.__new__(cls) # type: ignore
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
File ~/miniconda3/envs/project/lib/python3.10/site-packages/pandas/core/indexes/base.py:526, in Index.__new__(cls, data, dtype, copy, name, tupleize_cols)
523 data = com.asarray_tuplesafe(data, dtype=_dtype_obj)
525 elif is_scalar(data):
--> 526 raise cls._raise_scalar_data_error(data)
527 elif hasattr(data, "__array__"):
528 return cls(np.asarray(data), dtype=dtype, copy=copy, name=name)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/pandas/core/indexes/base.py:5289, in Index._raise_scalar_data_error(cls, data)
5284 @final
5285 @classmethod
5286 def _raise_scalar_data_error(cls, data):
5287 # We return the TypeError so that we can raise it from the constructor
5288 # in order to keep mypy happy
-> 5289 raise TypeError(
5290 f"{cls.__name__}(...) must be called with a collection of some "
5291 f"kind, {repr(data) if not isinstance(data, np.generic) else str(data)} "
5292 "was passed"
5293 )
TypeError: Index(...) must be called with a collection of some kind, None was passed

ENV
Python 3.10.14
dill==0.3.8
docker==7.1.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
filelock==3.16.1
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
gevent==24.2.1
geventhttpclient==2.0.2
google-pasta==0.2.0
greenlet==3.0.3
grpcio==1.64.1
huggingface-hub==0.26.2
humanfriendly==10.0
idna==3.7
imbalanced-learn==0.12.0
importlib-metadata==6.11.0
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1719582526268/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
jmespath==1.0.1
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1716472197302/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257277185/work
kiwisolver==1.4.5
lightgbm==4.5.0
matplotlib==3.9.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
numpy==1.26.4
oc==0.2.1
onnx==1.17.0
onnxconverter-common==1.14.0
onnxmltools==1.12.0
onnxruntime==1.16.3
oracledb==2.0.1
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
pandas==2.2.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
pathos==0.3.2
patsy==0.5.6
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
pillow==10.4.0
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
pox==0.3.4
ppft==1.7.6.8
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
protobuf==3.20.2
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1719274566094/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
pyathena==2.3.2
pycparser==2.22
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
pyparsing==3.1.2
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
python-rapidjson==1.17
pytz==2024.1
PyYAML==6.0.1
pyzmq @ file:///croot/pyzmq_1705605076900/work
referencing==0.35.1
requests==2.32.3
rpds-py==0.19.0
s3transfer==0.10.1
sagemaker==2.226.0
schema==0.7.7
scikit-learn==1.4.0
scipy==1.13.1
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
skl2onnx==1.17.0
skops==0.10.0
smdebug-rulesconfig==1.0.1
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
statsmodels==0.14.2
sympy==1.13.3
tabulate==0.9.0
tblib==3.0.0
tenacity==8.4.1
threadpoolctl==3.5.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1717722796999/work
tqdm==4.66.4
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
tritonclient==2.46.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
tzdata==2024.1
urllib3==2.2.2
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
xgboost==2.1.2
yarl==1.9.4
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1718013267051/work
zope.event==5.0
zope.interface==6.4.post2
Metadata
Metadata
Assignees
Labels
No labels