Skip to content

Commit 1b3c315

Browse files
authored
Merge branch 'main' into honestoblique_yuxin
2 parents 214d689 + ccccf9c commit 1b3c315

File tree

6 files changed

+54
-25
lines changed

6 files changed

+54
-25
lines changed

meson.build

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ project(
44
# Note that the git commit hash cannot be added dynamically here
55
# That only happens when importing from a git repository.
66
# See `treeple/__init__.py`
7-
version: '0.10.0.dev0',
7+
version: '0.10.3',
88
license: 'PolyForm Noncommercial 1.0.0',
99
meson_version: '>= 1.1.0',
1010
default_options: [

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ requires = [
1010
"setuptools<=65.5",
1111
"packaging",
1212
"Cython>=3.0.10",
13-
"scikit-learn>=1.5.0",
13+
"scikit-learn>=1.6.0",
1414
"scipy>=1.5.0",
1515
"numpy>=1.25; python_version>='3.9'"
1616
]
1717

1818
[project]
1919
name = "treeple"
20-
version = "0.10.0.dev0"
20+
version = "0.10.3"
2121
description = "Modern decision trees in Python"
2222
maintainers = [
2323
{name = "Neurodata", email = "adam.li@columbia.edu"}
@@ -52,7 +52,7 @@ include = [
5252
dependencies = [
5353
'numpy>=1.25.0',
5454
'scipy>=1.5.0',
55-
'scikit-learn>=1.5.0'
55+
'scikit-learn>=1.6.0'
5656
]
5757

5858
[project.optional-dependencies]
@@ -70,7 +70,7 @@ build = [
7070
'meson-python',
7171
'spin>=0.12',
7272
'doit',
73-
'scikit-learn>=1.5.0',
73+
'scikit-learn>=1.6.0',
7474
'Cython>=3.0.10',
7575
'ninja',
7676
'numpy>=1.25.0',

treeple/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import sys
66

7-
__version__ = "0.10.0dev0"
7+
__version__ = "0.10.3"
88
logger = logging.getLogger(__name__)
99

1010

treeple/ensemble/_honest_forest.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class HonestForestClassifier(ForestClassifier, ForestClassifierMixin):
182182
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
183183
if ``sample_weight`` is passed.
184184
185-
bootstrap : bool, default=False
185+
bootstrap : bool, default=True
186186
Whether bootstrap samples are used when building trees. If False, the
187187
whole dataset is used to build each tree.
188188
@@ -270,24 +270,30 @@ class HonestForestClassifier(ForestClassifier, ForestClassifierMixin):
270270
Fraction of training samples used for estimates in the trees. The
271271
remaining samples will be used to learn the tree structure. A larger
272272
fraction creates shallower trees with lower variance estimates.
273-
273+
274274
honest_method : {"prune", "apply"}, default="prune"
275275
Method for enforcing honesty. If "prune", the tree is pruned to enforce
276276
honesty. If "apply", the tree is not pruned, but the leaf estimates are
277277
adjusted to enforce honesty.
278278
279+
kernel_method : bool, default=True
280+
Method for normalizing ``predict_proba`` posteriors by the number of
281+
samples in the leaf nodes across the forest. Contrary to the average of
282+
posteriors, the kernel method only normalizes the probabilities once.
283+
By default True.
284+
279285
tree_estimator : object, default=None
280286
Instantiated tree of type BaseDecisionTree from treeple.
281287
If None, then sklearn's DecisionTreeClassifier with default parameters will
282288
be used. Note that none of the parameters in ``tree_estimator`` need
283289
to be set. The parameters of the ``tree_estimator`` can be set using
284290
the ``tree_estimator_params`` keyword argument.
285291
286-
stratify : bool
292+
stratify : bool, default=True
287293
Whether or not to stratify sample when considering structure and leaf indices.
288294
This will also stratify samples when bootstrap sampling is used. For more
289295
information, see :func:`sklearn.utils.resample`.
290-
By default False.
296+
By default True.
291297
292298
**tree_estimator_params : dict
293299
Parameters to pass to the underlying base tree estimators.
@@ -462,12 +468,13 @@ def __init__(
462468
warm_start=False,
463469
class_weight=None,
464470
ccp_alpha=0.0,
465-
max_samples=None,
471+
max_samples=1.6,
466472
honest_prior="ignore",
467473
honest_fraction=0.5,
468-
honest_method="apply",
474+
honest_method="prune",
475+
kernel_method=True,
469476
tree_estimator=None,
470-
stratify=False,
477+
stratify=True,
471478
**tree_estimator_params,
472479
):
473480
super().__init__(
@@ -490,6 +497,7 @@ def __init__(
490497
"honest_prior",
491498
"honest_method",
492499
"stratify",
500+
"kernel_method",
493501
),
494502
bootstrap=bootstrap,
495503
oob_score=oob_score,
@@ -513,7 +521,7 @@ def __init__(
513521
self.honest_fraction = honest_fraction
514522
self.honest_prior = honest_prior
515523
self.honest_method = honest_method
516-
print(self.honest_method)
524+
self.kernel_method = kernel_method
517525
self.tree_estimator = tree_estimator
518526
self.stratify = stratify
519527
self._tree_estimator_params = tree_estimator_params

treeple/tree/_honest_tree.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ class HonestTreeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseDecisionTree
186186
classes). If "empirical", the prior tree posterior is the relative
187187
class frequency in the voting subsample.
188188
189-
stratify : bool
189+
stratify : bool, default=True
190190
Whether or not to stratify sample when considering structure and leaf indices.
191-
By default False.
191+
By default True.
192192
193193
honest_method : {"apply", "prune"}, default="apply"
194194
Method to use for fitting the leaf nodes. If "apply", the leaf nodes
@@ -197,6 +197,12 @@ class frequency in the voting subsample.
197197
by pruning using the honest-set of data after the tree structure is built
198198
using the structure-set of data.
199199
200+
kernel_method : bool, default=False
201+
Method for normalizing ``predict_proba`` posteriors by the number of
202+
samples in the leaf nodes across the forest. Not applicable to single
203+
honest trees.
204+
By default False.
205+
200206
**tree_estimator_params : dict
201207
Parameters to pass to the underlying base tree estimators.
202208
These must be parameters for ``tree_estimator``.
@@ -338,8 +344,9 @@ def __init__(
338344
monotonic_cst=None,
339345
honest_fraction=0.5,
340346
honest_prior="empirical",
341-
stratify=False,
347+
stratify=True,
342348
honest_method="apply",
349+
kernel_method=False,
343350
**tree_estimator_params,
344351
):
345352
self.tree_estimator = tree_estimator
@@ -361,6 +368,7 @@ def __init__(
361368
self.honest_prior = honest_prior
362369
self.stratify = stratify
363370
self.honest_method = honest_method
371+
self.kernel_method = kernel_method
364372

365373
# XXX: to enable this, we need to also reset the leaf node samples during `_set_leaf_nodes`
366374
self.store_leaf_values = False
@@ -876,9 +884,11 @@ class in a leaf.
876884

877885
if self.n_outputs_ == 1:
878886
proba = proba[:, : self._tree_n_classes_]
879-
# normalizer = proba.sum(axis=1)[:, np.newaxis]
880-
# normalizer[normalizer == 0.0] = 1.0
881-
# proba /= normalizer
887+
888+
if not self.kernel_method:
889+
normalizer = proba.sum(axis=1)[:, np.newaxis]
890+
normalizer[normalizer == 0.0] = 1.0
891+
proba /= normalizer
882892
proba = self._empty_leaf_correction(proba)
883893

884894
return proba
@@ -888,10 +898,13 @@ class in a leaf.
888898

889899
for k in range(self.n_outputs_):
890900
proba_k = proba[:, k, : self._tree_n_classes_[k]]
891-
normalizer = proba_k.sum(axis=1)[:, np.newaxis]
892-
# normalizer[normalizer == 0.0] = 1.0
893-
# proba_k /= normalizer
894-
# proba_k = self._empty_leaf_correction(proba_k, k)
901+
902+
if not self.kernel_method:
903+
normalizer = proba_k.sum(axis=1)[:, np.newaxis]
904+
normalizer[normalizer == 0.0] = 1.0
905+
proba_k /= normalizer
906+
proba_k = self._empty_leaf_correction(proba_k, k)
907+
895908
all_proba.append(proba_k)
896909

897910
return all_proba

treeple/tree/honesty/_honest_prune.pyx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ cdef class HonestPruner(Splitter):
152152
self.samples[current_end], self.samples[p]
153153
n_missing += 1
154154
current_end -= 1
155-
elif p > pos and (self.tree._compute_feature(X_ndarray, sample_idx, &self.tree.nodes[node_idx]) <= threshold):
155+
156+
# Leverage sklearn's forked API to compute the feature value at this split node
157+
# and then compare that to the corresponding threshold
158+
# Note: this enables the function to work w/ both axis-aligned and oblique splits.
159+
elif p > pos and (self.tree._compute_feature(X_ndarray, sample_idx, &self.tree.nodes[node_idx])<= threshold):
156160
self.samples[p], self.samples[pos] = \
157161
self.samples[pos], self.samples[p]
158162
pos += 1
@@ -367,8 +371,12 @@ cdef _honest_prune(
367371
split_is_degenerate = (
368372
pruner.n_left_samples() == 0 or pruner.n_right_samples() == 0
369373
)
374+
370375
is_leaf_in_origtree = child_l[node_idx] == _TREE_LEAF
376+
371377
if invalid_split or split_is_degenerate or is_leaf_in_origtree:
378+
# invalid_split or is_leaf_in_origtree:
379+
# or split_is_degenerate or is_leaf_in_origtree:
372380
# ... and child_r[node_idx] == _TREE_LEAF:
373381
#
374382
# 1) if node is not degenerate, that means there are still honest-samples in

0 commit comments

Comments
 (0)