Skip to content

Commit 2d38a33

Browse files
authored
clear up confusing error message for non-contiguous labels in rf classifier [skip-ci] (#994)
+ add docstring note about this requirement Signed-off-by: Erik Ordentlich <[email protected]>
1 parent 4965ecf commit 2d38a33

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

python/src/spark_rapids_ml/classification.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,13 @@ class RandomForestClassifier(
418418
max_batch_size: int (default = 4096)
419419
Maximum number of nodes that can be processed in a given batch.
420420
421+
Notes
422+
-----
423+
The label column is required to be an integer in the range ``0, 1, ..., num_classes - 1``. Moreover, for fit() to succeed,
424+
all values in this range are required to be present in the input data and also each worker must receive the full range of values.
425+
If this is not the case, an error will be raised; possible workarounds are to remap the labels to the expected range,
426+
increase the number of very rare label occurrences in the input data, rerun with fewer workers, or shuffle the input data.
427+
421428
Examples
422429
--------
423430
>>> import numpy

python/src/spark_rapids_ml/tree.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,13 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]:
413413
# Fit a random forest model on the dataset (X, y)
414414
rf.fit(X, y, convert_dtype=False)
415415

416+
missing_labels_error_message = "A GPU worker did not receive all label values in the range 0, 1, ..., num_classes - 1, which is currently required. \
417+
Depending on the root cause, possible work arounds are to remap the labels to the required range, increase the number \
418+
of very rare label occurrences in the input data, rerun with fewer workers, or shuffle the input data."
419+
416420
if is_classification:
417421
if rf.classes_.max() != rf.n_classes_ - 1:
418-
raise RuntimeError(
419-
"A GPU worker did not receive all label values. Rerun with fewer workers or shuffle input data."
420-
)
422+
raise RuntimeError(missing_labels_error_message)
421423

422424
# serialized_model is Dictionary type
423425
serialized_model = rf._treelite_model_bytes
@@ -451,11 +453,9 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]:
451453

452454
exc_str = traceback.format_exc()
453455
if "different num_class than the first model object" in exc_str:
454-
raise RuntimeError(
455-
"Some GPU workers did not receive all label values. Rerun with fewer workers or shuffle input data."
456-
)
456+
raise RuntimeError(missing_labels_error_message)
457457
else:
458-
raise err
458+
raise
459459

460460
final_model_bytes = pickle.dumps(_treelite_model_bytes)
461461
final_model = base64.b64encode(final_model_bytes).decode("utf-8")

0 commit comments

Comments
 (0)