Description
Describe the bug
The probabilities returned by LocalClassifierPerParentNode.predict_proba(X) for a sample depend on the other samples in X.
To Reproduce
One example could be the following:
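```python
import numpy as np
from hiclass import LocalClassifierPerParentNode


def test_predict_proba_does_not_depend_on_samples_in_batch() -> None:
    train_x = np.array([[0], [1]])
    train_y = [[0, 0], [1, 1]]
    classifier = LocalClassifierPerParentNode()
    classifier.fit(train_x, train_y)
    # First sample's probabilities, predicted together with the second sample
    predict_all_take_first = classifier.predict_proba(train_x)[0]
    # First sample's probabilities, predicted on its own
    predict_first_take_first = classifier.predict_proba(train_x[:1])[0]
    np.testing.assert_allclose(predict_all_take_first, predict_first_take_first)
```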
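The assertion fails with:

```
E x: array([0.555353, 0.444647])
E y: array([1., 0.])
```

Here x is the first sample's prediction from the two-sample batch, and y is its prediction when it is passed on its own.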
Expected behavior
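The probabilities predicted for a sample should not depend on the other samples in the batch, i.e. classifier.predict_proba(train_x)[0] should equal classifier.predict_proba(train_x[:1])[0] (up to floating-point tolerance).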
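The repro only checks the first sample. Below is a minimal sketch of a more general check that compares every row of a batch prediction with the prediction for that row alone. It assumes predict_proba returns a 2-D array with one row of probabilities per sample, which is what the [0] indexing in the repro suggests; adapt it if your version returns something else. The helper assert_batch_invariant and the two toy probability functions are hypothetical and written only for illustration. They are not part of hiclass.

```python
import numpy as np


def assert_batch_invariant(predict_proba, X, rtol=1e-7, atol=0.0):
    """Assert that each sample's probabilities do not depend on the rest of the batch.

    Hypothetical helper. Assumes predict_proba maps a 2-D array of samples to a
    2-D array with one row of probabilities per sample.
    """
    X = np.asarray(X)
    batch = np.asarray(predict_proba(X))
    for i in range(X.shape[0]):
        # Predict sample i on its own; slicing keeps X two-dimensional
        single = np.asarray(predict_proba(X[i : i + 1]))[0]
        np.testing.assert_allclose(
            batch[i],
            single,
            rtol=rtol,
            atol=atol,
            err_msg=f"sample {i} depends on the other samples in the batch",
        )


def rowwise_proba(X):
    # Toy, batch-invariant model: softmax over [x, -x], computed row by row
    logits = np.hstack([X, -X]).astype(float)
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


def batch_dependent_proba(X):
    # Toy, deliberately buggy model: centring by the batch mean couples the rows
    return rowwise_proba(X - X.mean(axis=0))


X = np.array([[0], [1]])
assert_batch_invariant(rowwise_proba, X)  # passes
try:
    assert_batch_invariant(batch_dependent_proba, X)
except AssertionError:
    print("batch dependence detected")
```

Applied to the fitted classifier from the repro, assert_batch_invariant(classifier.predict_proba, train_x) would be expected to fail on sample 0, given the output reported above.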
Versions (pip list output; columns: Package, Version, Editable project location):
alabaster 0.7.16
attrs 24.2.0
babel 2.16.0
black 24.3.0
boto3 1.35.14
botocore 1.35.14
certifi 2024.8.30
cfgv 3.4.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
contourpy 1.3.0
coverage 7.6.1
cycler 0.12.1
distlib 0.3.8
docutils 0.17.1
filelock 3.16.0
flake8 4.0.1
fonttools 4.53.1
fsspec 2024.9.0
hiclass <redacted, 5.0.4, master@76f7953>
hypothesis 6.150.0
identify 2.6.0
idna 3.8
imagesize 1.4.1
importlib-metadata 8.4.0
iniconfig 2.0.0
jaraco-classes 3.4.0
jaraco-context 6.0.1
jaraco-functools 4.0.2
jinja2 3.1.4
jmespath 1.0.1
joblib 1.4.2
jsonschema 4.26.0
jsonschema-specifications 2025.9.1
keyring 25.3.0
kiwisolver 1.4.7
llvmlite 0.43.0
markdown-it-py 3.0.0
markupsafe 2.1.5
matplotlib 3.9.2
mccabe 0.6.1
mdurl 0.1.2
more-itertools 10.5.0
mpmath 1.3.0
msgpack 1.1.2
mypy-extensions 1.0.0
networkx 3.3
nh3 0.2.18
nodeenv 1.9.1
numba 0.60.0
numpy 1.26.4
packaging 24.1
pandas 1.4.2
pathspec 0.12.1
pillow 10.4.0
pipenv 2026.0.3
pkginfo 1.10.0
platformdirs 4.3.1
pluggy 1.5.0
pre-commit 2.20.0
protobuf 6.33.3
py 1.11.0
pycodestyle 2.8.0
pydocstyle 6.1.1
pyfakefs 5.6.0
pyflakes 2.4.0
pygments 2.18.0
pyparsing 3.1.4
pytest 7.1.2
pytest-cov 3.0.0
pytest-flake8 1.1.1
pytest-pydocstyle 2.3.0
python-dateutil 2.9.0.post0
pytz 2024.1
pyyaml 6.0.2
ray 2.53.0
readme-renderer 43.0
readthedocs-sphinx-search 0.1.2
referencing 0.37.0
requests 2.32.3
requests-toolbelt 1.0.0
rfc3986 2.0.0
rich 13.8.0
rpds-py 0.30.0
s3transfer 0.10.2
scikit-learn 1.6.0
scipy 1.11.4
setuptools 80.9.0
shap 0.44.1
six 1.16.0
slicer 0.0.7
snowballstemmer 2.2.0
sortedcontainers 2.4.0
sphinx 5.0.0
sphinx-code-tabs 0.5.3
sphinx-gallery 0.10.1
sphinx-rtd-theme 1.0.0
sphinxcontrib-applehelp 2.0.0
sphinxcontrib-devhelp 2.0.0
sphinxcontrib-htmlhelp 2.1.0
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 2.0.0
sphinxcontrib-serializinghtml 2.0.0
sympy 1.13.2
threadpoolctl 3.5.0
toml 0.10.2
tomli 2.0.1
torch 2.4.1
tqdm 4.66.5
twine 5.1.1
typing-extensions 4.12.2
urllib3 2.2.2
virtualenv 20.26.4
xarray 2023.1.0
zipp 3.20.1