[WIP] Remove legacy FIL #6603
// Allocate workspace and perform segmented reduce
thrust::device_vector<kv_type> workspace2(
  params.n_rows + temp_storage_bytes / sizeof(kv_type) + 1);
cub::DeviceSegmentedReduce::ArgMax(
  thrust::raw_pointer_cast(workspace2.data() + params.n_rows), temp_storage_bytes,
  workspace->begin(), workspace2.begin(),
  params.n_rows, offsets_it, offsets_it + 1);
thrust::transform(workspace2.begin(), workspace2.begin() + params.n_rows, pred->begin(),
  [] __device__ (kv_type x) -> int { return x.key; });
Note. The new FIL always returns probabilities for classifiers. If class outputs are desired, we need to perform argmax.
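As an illustration of the argmax step described in this note, here is a minimal host-side sketch. It assumes the probabilities are stored row-major as n_rows x n_classes; the helper name `argmax_per_row` is hypothetical and is not part of FIL. The diff above performs the reduction on the device with cub::DeviceSegmentedReduce::ArgMax instead.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not FIL API): reduce a row-major
// n_rows x n_classes probability matrix to one class index per row.
std::vector<int> argmax_per_row(const std::vector<float>& probs,
                                std::size_t n_rows,
                                std::size_t n_classes) {
  if (n_classes == 0 || probs.size() != n_rows * n_classes) {
    throw std::invalid_argument("probs must hold n_rows * n_classes values");
  }
  std::vector<int> labels(n_rows);
  for (std::size_t r = 0; r < n_rows; ++r) {
    const float* row = probs.data() + r * n_classes;
    std::size_t best = 0;
    for (std::size_t c = 1; c < n_classes; ++c) {
      // Strict comparison: on ties this sketch keeps the lowest class index.
      if (row[c] > row[best]) { best = c; }
    }
    labels[r] = static_cast<int>(best);
  }
  return labels;
}
```

For two rows {0.1, 0.7, 0.2} and {0.5, 0.25, 0.25}, this sketch returns classes 1 and 0.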
// Handle a degenerate tree with a childless root node
if (aligned_data.inner_data.distant_offset == 0) {
  return offset_type{} + aligned_data.inner_data.distant_offset;
}
This provides a fix for a subtle memory bug. Given a degenerate tree with a single childless root node, the node struct is initialized with distant_offset = 0. When the inference kernel accesses the next node at index condition + (distant_offset - 1), it tries to access index -1, leading to out-of-bounds access.
I was able to discover the bug with the help of property-based gtests.
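The arithmetic behind the bug can be sketched in isolation. The functions below are hypothetical stand-ins, not the FIL node API; they only reproduce the expression quoted above, using a signed type so the negative result is visible.

```cpp
#include <cstdint>

// Hypothetical stand-in for the offset expression quoted in the comment:
// condition + (distant_offset - 1).
std::int64_t relative_offset(bool condition, std::int64_t distant_offset) {
  return static_cast<std::int64_t>(condition) + (distant_offset - 1);
}

// Index of the node visited next, starting from `current`.
std::int64_t next_index(std::int64_t current, bool condition,
                        std::int64_t distant_offset) {
  return current + relative_offset(condition, distant_offset);
}

// True if `index` refers to a node inside a tree of `n_nodes` nodes.
bool in_bounds(std::int64_t index, std::int64_t n_nodes) {
  return index >= 0 && index < n_nodes;
}
```

With a one-node tree, a root at index 0, and distant_offset = 0, a false condition gives next_index(0, false, 0) == -1, which in_bounds rejects. A root with distant_offset = 2 in a three-node tree yields index 1 or 2, both valid.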
When I was first working on new FIL, we decided (consistent with legacy FIL) to just disallow degenerate trees. That always bothered me a little, so I'm glad to see it being handled properly now.
That being said, I would recommend handling this differently. I expect the performance impact of this change will be quite severe, and it also will not correctly handle degenerate trees.
In terms of the performance impact, this introduces an additional branch in the tight loop of tree evaluation, which will almost certainly add more instructions than we managed to eliminate moving to new FIL. It is also an inference-time check for a very rare condition that could be detected when the Treelite model is ingested. Finally (and probably least significantly), it prevents RVO through the addition of a separate return statement on a runtime-evaluated branch.
In terms of correctness, this change also will not account for undefined behavior here in the case of degenerate trees. Because this is a do-while loop instead of a while loop, we will access the unset member of a union in the body of the loop, leading to undefined behavior. I know, having tested it previously, that switching to a while loop has a serious negative performance impact, especially for shallow trees.
Instead, I would recommend the following:
1. In this PR, introduce a check at the time of translation from Treelite for degenerate trees. If any appear, throw an exception.
2. In a follow-on PR, update the translation logic to explicitly handle the case of degenerate trees. Promote degenerate trees to a trivial non-degenerate stump which returns the same output for both branches.
Step 2 might not even be necessary given how rare degenerate trees are, but it would be a nice thing to offer for completeness.
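The import-time check recommended above could look roughly like the following. This is a sketch under stated assumptions: `Node`, `Tree`, `is_degenerate`, and `check_no_degenerate_trees` are hypothetical names over a simplified tree layout, not the Treelite or FIL data structures.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified tree layout: child indices of -1 mean "no child".
struct Node {
  int left;
  int right;
  float value;
};
using Tree = std::vector<Node>;  // node 0 is the root

// A tree is degenerate when its root has no children.
bool is_degenerate(const Tree& tree) {
  return !tree.empty() && tree[0].left < 0 && tree[0].right < 0;
}

// Run once while the model is translated, so the inference loop
// never needs to branch on this case.
void check_no_degenerate_trees(const std::vector<Tree>& forest) {
  for (std::size_t i = 0; i < forest.size(); ++i) {
    if (is_degenerate(forest[i])) {
      throw std::invalid_argument("tree " + std::to_string(i) +
                                  " has a childless root node");
    }
  }
}
```

The point of the design is that the cost is paid once per model load rather than once per node visit during inference.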
Thanks for taking time to review! Yes, I was feeling nervous about inserting a new branch in the middle of a tight loop but was in a rush to fix an unexpected memory bug.
> we decided (consistent with legacy FIL) to just disallow degenerate trees
I was not aware of this limitation, especially because the new FIL doesn't throw an exception when ingesting a degenerate tree.
> Step 2 might not even be necessary given how rare degenerate trees are
Degenerate trees may be more common than we realize. I saw XGBoost produce degenerate trees when I set regularization knobs too high. Of course, users should be diligent in inspecting the trained model, but FIL should support degenerate trees for the sake of completeness.
Also, as it turns out, 37 out of the 180 unit tests from SG_RF_TEST involve cases containing at least one degenerate tree. (This is actually how I found the issue with degenerate trees.) These unit tests arise from property-based testing, where training hyperparameters are set randomly to train a Random Forest model. So I think we should provide support for degenerate trees.
Since I am making substantial changes to the FIL code base anyway, I will go ahead and implement Step 2 in this feature branch.
For the public record, here's my plan:
- Complete all work on this feature branch, including the removal of legacy FIL and the built-in conversion of degenerate trees to valid stumps. This ensures that all changes work together.
- The feature branch contains a lot of changes and is hard to review all at once, so I will break it down into three separate pull requests (PRs):
  - The first PR will update the new FIL to correctly handle degenerate trees (by converting them to valid tree stumps).
  - The second PR will remove the legacy FIL from the Python layer of cuML RF.
  - The third PR will remove the legacy FIL from the C++ layer of cuML RF (+ gtest).
The latest commit implements the translation logic to convert all degenerate trees to tree stumps.
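The commit itself is not reproduced in this thread, so the following is only a sketch of what such a conversion might look like. `Node`, `Tree`, `promote_to_stump`, and `predict` are hypothetical names over a simplified host-side layout; the split feature and threshold chosen for the stump are arbitrary because both leaves carry the same output.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, simplified tree layout: child indices of -1 mean "leaf".
struct Node {
  int left;
  int right;
  int feature;
  float threshold;
  float value;
};
using Tree = std::vector<Node>;  // node 0 is the root

// Replace a single childless root with a root plus two identical leaves.
// Any other tree is returned unchanged.
Tree promote_to_stump(const Tree& tree) {
  if (tree.size() != 1 || tree[0].left >= 0 || tree[0].right >= 0) {
    return tree;
  }
  const float out = tree[0].value;
  Tree stump;
  stump.push_back(Node{1, 2, 0, 0.0f, 0.0f});   // root: arbitrary split on feature 0
  stump.push_back(Node{-1, -1, -1, 0.0f, out});  // left leaf
  stump.push_back(Node{-1, -1, -1, 0.0f, out});  // right leaf, same output
  return stump;
}

// Reference traversal used only to compare outputs; it uses a plain
// while loop for clarity, unlike the do-while loop discussed in the review.
float predict(const Tree& tree, const std::vector<float>& row) {
  std::size_t i = 0;
  while (tree[i].left >= 0) {
    const Node& n = tree[i];
    const bool go_left = row[static_cast<std::size_t>(n.feature)] < n.threshold;
    i = static_cast<std::size_t>(go_left ? n.left : n.right);
  }
  return tree[i].value;
}
```

After promotion, every input row reaches a leaf holding the original root's output, so predictions are unchanged while the root now has two real children.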
I am closing this PR and will submit three smaller PRs to replace it.