Skip to content

Commit a34755a

Browse files
authored
Support __len__ in ensemble estimators (#6468)
`sklearn` ensemble estimators are valid sequences of estimators. Supporting `__getitem__` and `__iter__` is _hard_ with our current implementation, but `__len__` is easy and lets more of the sklearn compatibility tests pass. Fixes #6465. Authors: - Jim Crist-Harif (https://github.com/jcrist) Approvers: - Simon Adorf (https://github.com/csadorf) URL: #6468
1 parent 84ebf2f commit a34755a

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

python/cuml/cuml/ensemble/randomforest_common.pyx

+4
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@ class BaseRandomForestModel(UniversalBase):
192192
self.treelite_serialized_model = None
193193
self._cpu_model_class_lock = threading.RLock()
194194

195+
def __len__(self):
196+
"""Return the number of estimators in the ensemble."""
197+
return self.n_estimators
198+
195199
def _get_max_feat_val(self) -> float:
196200
if isinstance(self.max_features, int):
197201
return self.max_features/self.n_cols

python/cuml/cuml/tests/test_random_forest.py

+7
Original file line numberDiff line numberDiff line change
@@ -1447,3 +1447,10 @@ def test_rf_predict_returns_int():
14471447
clf = cuml.ensemble.RandomForestClassifier().fit(X, y)
14481448
pred = clf.predict(X)
14491449
assert pred.dtype == np.int64
1450+
1451+
1452+
def test_ensemble_estimator_length():
1453+
X, y = make_classification()
1454+
clf = cuml.ensemble.RandomForestClassifier(n_estimators=3)
1455+
clf.fit(X, y)
1456+
assert len(clf) == 3

0 commit comments

Comments
 (0)