Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,13 @@ class RandomForestClassifier(
max_batch_size: int (default = 4096)
Maximum number of nodes that can be processed in a given batch.
Notes
-----
The label column is required to contain integers in the range ``0, 1, ..., num_classes - 1``. Moreover, for fit() to succeed,
every value in this range must be present in the input data, and each worker must receive the full range of values.
If this is not the case, an error will be raised. Possible workarounds are to remap the labels to the expected range,
increase the number of occurrences of very rare labels in the input data, rerun with fewer workers, or shuffle the input data.
Examples
--------
>>> import numpy
Expand Down
14 changes: 7 additions & 7 deletions python/src/spark_rapids_ml/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,13 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]:
# Fit a random forest model on the dataset (X, y)
rf.fit(X, y, convert_dtype=False)

missing_labels_error_message = "A GPU worker did not receive all label values in the range 0, 1, ..., num_classes - 1, which is currently required. \
Depending on the root cause, possible work arounds are to remap the labels to the required range, increase the number \
of very rare label occurrences in the input data, rerun with fewer workers, or shuffle the input data."

if is_classification:
if rf.classes_.max() != rf.n_classes_ - 1:
raise RuntimeError(
"A GPU worker did not receive all label values. Rerun with fewer workers or shuffle input data."
)
raise RuntimeError(missing_labels_error_message)

# serialized_model is Dictionary type
serialized_model = rf._treelite_model_bytes
Expand Down Expand Up @@ -451,11 +453,9 @@ def _single_fit(rf: cuRf) -> Dict[str, Any]:

exc_str = traceback.format_exc()
if "different num_class than the first model object" in exc_str:
raise RuntimeError(
"Some GPU workers did not receive all label values. Rerun with fewer workers or shuffle input data."
)
raise RuntimeError(missing_labels_error_message)
else:
raise err
raise

final_model_bytes = pickle.dumps(_treelite_model_bytes)
final_model = base64.b64encode(final_model_bytes).decode("utf-8")
Expand Down