Skip to content

Conversation

@eordentlich
Copy link
Collaborator

  • add docstring note about this requirement

…sifier and add docstring note about this requirement

Signed-off-by: Erik Ordentlich <[email protected]>
@eordentlich eordentlich changed the title clear up confusing error message for non-contiguous labels in rf classifier clear up confusing error message for non-contiguous labels in rf classifier [skip-ci] Nov 4, 2025
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Improved error messaging for RandomForestClassifier when GPU workers don't receive all required label values. The changes consolidate duplicate error messages and provide clearer guidance on workarounds, while adding comprehensive docstring documentation about the label requirements.

Key Changes

  • Consolidated error messages into a single, reusable string variable in tree.py:416-418
  • Updated error message to explicitly state the requirement for labels in range 0, 1, ..., num_classes - 1
  • Added detailed docstring notes in classification.py:421-426 explaining label requirements and workarounds
  • Error messages now provide actionable workarounds: remap labels, increase rare label occurrences, reduce workers, or shuffle input data
  • Removed redundant error message text from two locations in the code

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are purely documentation and error message improvements without any logic modifications. The error handling paths remain identical, only the message text is improved for clarity and consolidated for maintainability. No behavioral changes are introduced.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
python/src/spark_rapids_ml/classification.py 5/5 Added docstring notes clarifying label requirements for RandomForestClassifier
python/src/spark_rapids_ml/tree.py 5/5 Improved error messages for missing label values with actionable workarounds

Sequence Diagram

sequenceDiagram
    participant User
    participant RandomForestClassifier
    participant Worker as GPU Worker
    participant TreeLite
    
    User->>RandomForestClassifier: fit(data)
    RandomForestClassifier->>Worker: Distribute data across workers
    
    alt Classification Task
        Worker->>Worker: rf.fit(X, y)
        Worker->>Worker: Check if rf.classes_.max() != rf.n_classes_ - 1
        
        alt Missing Labels Detected
            Worker-->>User: RuntimeError: "A GPU worker did not receive all label values..."
        else All Labels Present
            Worker->>Worker: Serialize model
            Worker->>TreeLite: Create TreeLite models
            
            alt Worker 0 Concatenates Models
                TreeLite->>TreeLite: Model.concatenate(all_models)
                
                alt Concatenation Fails (different num_class)
                    TreeLite-->>User: RuntimeError: "A GPU worker did not receive all label values..."
                else Concatenation Succeeds
                    TreeLite->>RandomForestClassifier: Return final model
                    RandomForestClassifier->>User: Model trained successfully
                end
            end
        end
    end
Loading

2 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@eordentlich eordentlich merged commit 2d38a33 into NVIDIA:main Nov 5, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants