Skip to content

Commit 86cb6ba

Browse files
authored
Compatibility fix for XGBoost 3.2 (#655)
* Compatibility fix for XGBoost 3.2
* Address warning from XGBoost
* Check for leaf_weights field
1 parent 411113b commit 86cb6ba

3 files changed

Lines changed: 23 additions & 17 deletions

File tree

src/model_loader/detail/xgboost_json/delegated_handler.cc

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,7 @@ bool RegTreeHandler::StartArray() {
348348
return (push_key_handler<ArrayHandler<float>>("loss_changes", loss_changes)
349349
|| push_key_handler<ArrayHandler<float>>("sum_hessian", sum_hessian)
350350
|| push_key_handler<ArrayHandler<float>>("base_weights", base_weights)
351+
|| push_key_handler<ArrayHandler<float>>("leaf_weights", leaf_weights)
351352
|| push_key_handler<ArrayHandler<int>>("categories_segments", categories_segments)
352353
|| push_key_handler<ArrayHandler<int>>("categories_sizes", categories_sizes)
353354
|| push_key_handler<ArrayHandler<int>>("categories_nodes", categories_nodes)
@@ -390,10 +391,8 @@ bool RegTreeHandler::EndObject() {
390391
if (output.size_leaf_vector == 0) {
391392
output.size_leaf_vector = 1; // In XGBoost, size_leaf_vector=0 indicates a scalar output
392393
}
393-
if (num_nodes * output.size_leaf_vector != base_weights.size()) {
394-
TREELITE_LOG(ERROR) << "Field base_weights has an incorrect dimension. Expected: "
395-
<< (num_nodes * output.size_leaf_vector)
396-
<< ", Actual: " << base_weights.size();
394+
if (output.size_leaf_vector != 1 && leaf_weights.empty()) {
395+
TREELITE_LOG(ERROR) << "Field leaf_weights must be provided for multi-target trees.";
397396
return false;
398397
}
399398
if (static_cast<std::size_t>(num_nodes) != left_children.size()) {
@@ -440,9 +439,10 @@ bool RegTreeHandler::EndObject() {
440439
if (size_leaf_vector > 1) {
441440
// Vector output
442441
std::vector<float> leafvec(size_leaf_vector);
443-
std::transform(&base_weights[node_id * size_leaf_vector],
444-
&base_weights[(node_id + 1) * size_leaf_vector], leafvec.begin(),
445-
[](float e) { return static_cast<float>(e); });
442+
auto leaf_id = right_children[node_id];
443+
TREELITE_CHECK_NE(leaf_id, -1) << "Expected a leaf node at index " << node_id;
444+
std::copy(&leaf_weights[leaf_id * size_leaf_vector],
445+
&leaf_weights[(leaf_id + 1) * size_leaf_vector], leafvec.begin());
446446
model_builder.LeafVector(leafvec);
447447
} else {
448448
// Scalar leaf output
@@ -487,7 +487,7 @@ bool RegTreeHandler::is_recognized_key(std::string const& key) {
487487
|| key == "categories" || key == "leaf_child_counts" || key == "left_children"
488488
|| key == "right_children" || key == "parents" || key == "split_indices"
489489
|| key == "split_type" || key == "split_conditions" || key == "default_left"
490-
|| key == "tree_param" || key == "id");
490+
|| key == "tree_param" || key == "id" || key == "leaf_weights");
491491
}
492492

493493
/******************************************************************************

src/model_loader/detail/xgboost_json/delegated_handler.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -414,6 +414,7 @@ class RegTreeHandler : public OutputHandler<ParsedRegTreeParams> {
414414
std::vector<float> loss_changes;
415415
std::vector<float> sum_hessian;
416416
std::vector<float> base_weights;
417+
std::vector<float> leaf_weights;
417418
std::vector<int> left_children;
418419
std::vector<int> right_children;
419420
std::vector<int> parents;

tests/python/test_xgboost_integration.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,12 @@
3030
from .util import TemporaryDirectory, has_pandas, to_categorical
3131

3232

33+
def _get_model_filename(name, model_format):
34+
if model_format == "ubjson":
35+
model_format = "ubj"
36+
return f"{name}.{model_format}"
37+
38+
3339
def generate_data_for_squared_log_error(n_targets: int = 1):
3440
"""Generate data containing outliers."""
3541
n_rows = 4096
@@ -108,7 +114,7 @@ def test_xgb_regressor(
108114
num_boost_round=num_boost_round,
109115
)
110116
with TemporaryDirectory() as tmpdir:
111-
model_path = pathlib.Path(tmpdir) / f"model.{model_format}"
117+
model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format)
112118
xgb_model.save_model(model_path)
113119
tl_model = treelite.frontend.load_xgboost_model(
114120
model_path, format_choice=model_format
@@ -179,7 +185,7 @@ def test_xgb_multiclass_classifier(
179185
)
180186

181187
with TemporaryDirectory() as tmpdir:
182-
model_path = pathlib.Path(tmpdir) / f"iris.{model_format}"
188+
model_path = pathlib.Path(tmpdir) / _get_model_filename("iris", model_format)
183189
xgb_model.save_model(model_path)
184190
tl_model = treelite.frontend.load_xgboost_model(
185191
model_path, format_choice=model_format
@@ -262,10 +268,7 @@ def test_xgb_nonlinear_objective(
262268
)
263269

264270
objective_tag = objective.replace(":", "_")
265-
if model_format in ["json", "ubjson"]:
266-
model_name = f"nonlinear_{objective_tag}.{model_format}"
267-
else:
268-
model_name = f"nonlinear_{objective_tag}.deprecated"
271+
model_name = _get_model_filename(f"nonlinear_{objective_tag}", model_format)
269272
with TemporaryDirectory() as tmpdir:
270273
model_path = pathlib.Path(tmpdir) / model_name
271274
xgb_model.save_model(model_path)
@@ -458,7 +461,9 @@ def test_xgb_multi_target_binary_classifier(
458461
tl_model = treelite.frontend.from_xgboost(bst)
459462
else:
460463
with TemporaryDirectory() as tmpdir:
461-
model_path = pathlib.Path(tmpdir) / f"multi_target.{model_format}"
464+
model_path = pathlib.Path(tmpdir) / _get_model_filename(
465+
"multi_target", model_format
466+
)
462467
bst.save_model(model_path)
463468
tl_model = treelite.frontend.load_xgboost_model(
464469
model_path, format_choice=model_format
@@ -533,7 +538,7 @@ def test_xgb_multi_target_regressor(
533538
)
534539

535540
with TemporaryDirectory() as tmpdir:
536-
model_path = pathlib.Path(tmpdir) / f"model.{model_format}"
541+
model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format)
537542
xgb_model.save_model(model_path)
538543
tl_model = treelite.frontend.load_xgboost_model(
539544
model_path, format_choice=model_format
@@ -578,7 +583,7 @@ def test_xgb_detect_format(
578583
expected_pred = xgb_model.predict(xgb.DMatrix(X)).reshape((X.shape[0], 1, -1))
579584

580585
with TemporaryDirectory() as tmpdir:
581-
model_path = pathlib.Path(tmpdir) / f"model.{model_format}"
586+
model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format)
582587
xgb_model.save_model(model_path)
583588
detected_format = treelite.frontend._detect_xgboost_format(model_path)
584589
assert detected_format == model_format

0 commit comments

Comments
 (0)