Skip to content

skops can't load TargetEncoder object #450

@bimhud

Description

@bimhud

skops is unable to load a scikit-learn pipeline that contains a `TargetEncoder` object (from `category_encoders`).
The code below reproduces the issue.

Simple code for a classifier using Target Encoder

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from category_encoders import TargetEncoder
from sklearn.metrics import accuracy_score

# Sample data: one categorical column ('category'), two numeric feature
# columns and a binary 'target' column, 10 rows in total.
data = {
    'category': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'C', 'B', 'A'],
    'feature1': [10, 20, 10, 30, 20, 10, 30, 30, 20, 10],
    'feature2': [1, 2, 1, 3, 2, 1, 3, 3, 2, 1],
    'target': [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]
}

# Create DataFrame
df = pd.DataFrame(data)

# Separate features and target
X = df.drop('target', axis=1)
y = df['target']

# Encode the category column using Target Encoder.
# The encoder is fitted on all rows here, i.e. before the train/test split
# below. This fitted `encoder` is the object that is dumped with skops later
# in this report.
# NOTE(review): fitting on the full data lets target information from the
# test rows reach the encoding; harmless for the repro, but worth avoiding
# in a real pipeline.
encoder = TargetEncoder()
X['category'] = encoder.fit_transform(X['category'], y)

# Split the data into training and testing sets (20% held out, fixed seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize and train the classifier (constructed with default arguments)
clf = RandomForestClassifier()
clf.fit(X_train, y_train)

# Make predictions on the held-out rows
y_pred = clf.predict(X_test)

# Calculate accuracy of the predictions against the held-out targets
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

Use skops to dump the encoder to a file and list the untrusted types it contains

from skops import io as sio
# Persist the fitted TargetEncoder (not the classifier) to "test.sio".
sio.dump(encoder, "test.sio")

# Ask skops which types stored in the file are reported as untrusted,
# then print that list.
unknown_type = sio.get_untrusted_types(file="test.sio")
print(unknown_type)

Output of unknown_type

['builtins.object',
 'category_encoders.ordinal.OrdinalEncoder',
 'category_encoders.target_encoder.TargetEncoder',
 'numpy.dtype',
 'pandas._libs.index.Int64Engine',
 'pandas._libs.index.ObjectEngine',
 'pandas._libs.internals.BlockValuesRefs',
 'pandas.core.indexes.base.Index',
 'pandas.core.internals.managers.SingleBlockManager',
 'pandas.core.series.Series']

Use SKOPS to load the file

# Load the file, passing every type reported by get_untrusted_types as
# trusted; this is the call that raises the TypeError shown in the traceback.
sio.load("test.sio",trusted=unknown_type)

This raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[14], line 1
----> 1 sio.load("test.sio",trusted=unknown_type)

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_persist.py:152, in load(file, trusted)
    150     tree = get_tree(schema, load_context, trusted=trusted)
    151     audit_tree(tree)
--> 152     instance = tree.construct()
    154 return instance

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
    416 if not self.children["attrs"]:
    417     # nothing more to do
    418     return instance
--> 420 attrs = self.children["attrs"].construct()
    421 if attrs is not None:
    422     if hasattr(instance, "__setstate__"):

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
     77 key_types = self.children["key_types"].construct()
     78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79     content[k_type(key)] = val.construct()
     80 return content

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
     77 key_types = self.children["key_types"].construct()
     78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79     content[k_type(key)] = val.construct()
     80 return content

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
    416 if not self.children["attrs"]:
    417     # nothing more to do
    418     return instance
--> 420 attrs = self.children["attrs"].construct()
    421 if attrs is not None:
    422     if hasattr(instance, "__setstate__"):

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
     77 key_types = self.children["key_types"].construct()
     78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79     content[k_type(key)] = val.construct()
     80 return content

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
    416 if not self.children["attrs"]:
    417     # nothing more to do
    418     return instance
--> 420 attrs = self.children["attrs"].construct()
    421 if attrs is not None:
    422     if hasattr(instance, "__setstate__"):

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:179, in TupleNode._construct(self)
    175 def _construct(self):
    176     # Returns a tuple or a namedtuple instance.
    178     cls = gettype(self.module_name, self.class_name)
--> 179     content = tuple(value.construct() for value in self.children["content"])
    181     if self.isnamedtuple(cls):
    182         return cls(*content)

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:179, in <genexpr>(.0)
    175 def _construct(self):
    176     # Returns a tuple or a namedtuple instance.
    178     cls = gettype(self.module_name, self.class_name)
--> 179     content = tuple(value.construct() for value in self.children["content"])
    181     if self.isnamedtuple(cls):
    182         return cls(*content)

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:113, in ListNode._construct(self)
    111 def _construct(self):
    112     content_type = gettype(self.module_name, self.class_name)
--> 113     return content_type([item.construct() for item in self.children["content"]])

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:113, in <listcomp>(.0)
    111 def _construct(self):
    112     content_type = gettype(self.module_name, self.class_name)
--> 113     return content_type([item.construct() for item in self.children["content"]])

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
    163 if self._constructed is not UNINITIALIZED:
    164     return self._constructed
--> 165 self._constructed = self._construct()
    166 return self._constructed

File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:414, in ObjectNode._construct(self)
    408 cls = gettype(self.module_name, self.class_name)
    410 # Instead of simply constructing the instance, we use __new__, which
    411 # bypasses the __init__, and then we set the attributes. This solves the
    412 # issue of required init arguments. Note that the instance created here
    413 # might not be valid until all its attributes have been set below.
--> 414 instance = cls.__new__(cls)  # type: ignore
    416 if not self.children["attrs"]:
    417     # nothing more to do
    418     return instance

File ~/miniconda3/envs/project/lib/python3.10/site-packages/pandas/core/indexes/base.py:526, in Index.__new__(cls, data, dtype, copy, name, tupleize_cols)
    523         data = com.asarray_tuplesafe(data, dtype=_dtype_obj)
    525 elif is_scalar(data):
--> 526     raise cls._raise_scalar_data_error(data)
    527 elif hasattr(data, "__array__"):
    528     return cls(np.asarray(data), dtype=dtype, copy=copy, name=name)

File ~/miniconda3/envs/project/lib/python3.10/site-packages/pandas/core/indexes/base.py:5289, in Index._raise_scalar_data_error(cls, data)
   5284 @final
   5285 @classmethod
   5286 def _raise_scalar_data_error(cls, data):
   5287     # We return the TypeError so that we can raise it from the constructor
   5288     #  in order to keep mypy happy
-> 5289     raise TypeError(
   5290         f"{cls.__name__}(...) must be called with a collection of some "
   5291         f"kind, {repr(data) if not isinstance(data, np.generic) else str(data)} "
   5292         "was passed"
   5293     )

TypeError: Index(...) must be called with a collection of some kind, None was passed

Environment: Python 3.10.14, with the following packages installed:

dill==0.3.8
docker==7.1.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
filelock==3.16.1
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
gevent==24.2.1
geventhttpclient==2.0.2
google-pasta==0.2.0
greenlet==3.0.3
grpcio==1.64.1
huggingface-hub==0.26.2
humanfriendly==10.0
idna==3.7
imbalanced-learn==0.12.0
importlib-metadata==6.11.0
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1719582526268/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
jmespath==1.0.1
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1716472197302/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257277185/work
kiwisolver==1.4.5
lightgbm==4.5.0
matplotlib==3.9.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
numpy==1.26.4
oc==0.2.1
onnx==1.17.0
onnxconverter-common==1.14.0
onnxmltools==1.12.0
onnxruntime==1.16.3
oracledb==2.0.1
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
pandas==2.2.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
pathos==0.3.2
patsy==0.5.6
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
pillow==10.4.0
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
pox==0.3.4
ppft==1.7.6.8
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
protobuf==3.20.2
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1719274566094/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
pyathena==2.3.2
pycparser==2.22
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
pyparsing==3.1.2
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
python-rapidjson==1.17
pytz==2024.1
PyYAML==6.0.1
pyzmq @ file:///croot/pyzmq_1705605076900/work
referencing==0.35.1
requests==2.32.3
rpds-py==0.19.0
s3transfer==0.10.1
sagemaker==2.226.0
schema==0.7.7
scikit-learn==1.4.0
scipy==1.13.1
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
skl2onnx==1.17.0
skops==0.10.0
smdebug-rulesconfig==1.0.1
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
statsmodels==0.14.2
sympy==1.13.3
tabulate==0.9.0
tblib==3.0.0
tenacity==8.4.1
threadpoolctl==3.5.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1717722796999/work
tqdm==4.66.4
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
tritonclient==2.46.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
tzdata==2024.1
urllib3==2.2.2
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
xgboost==2.1.2
yarl==1.9.4
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1718013267051/work
zope.event==5.0
zope.interface==6.4.post2

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions