|
30 | 30 | from .util import TemporaryDirectory, has_pandas, to_categorical |
31 | 31 |
|
32 | 32 |
|
| 33 | +def _get_model_filename(name, model_format): |
| 34 | + if model_format == "ubjson": |
| 35 | + model_format = "ubj" |
| 36 | + return f"{name}.{model_format}" |
| 37 | + |
| 38 | + |
33 | 39 | def generate_data_for_squared_log_error(n_targets: int = 1): |
34 | 40 | """Generate data containing outliers.""" |
35 | 41 | n_rows = 4096 |
@@ -108,7 +114,7 @@ def test_xgb_regressor( |
108 | 114 | num_boost_round=num_boost_round, |
109 | 115 | ) |
110 | 116 | with TemporaryDirectory() as tmpdir: |
111 | | - model_path = pathlib.Path(tmpdir) / f"model.{model_format}" |
| 117 | + model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format) |
112 | 118 | xgb_model.save_model(model_path) |
113 | 119 | tl_model = treelite.frontend.load_xgboost_model( |
114 | 120 | model_path, format_choice=model_format |
@@ -179,7 +185,7 @@ def test_xgb_multiclass_classifier( |
179 | 185 | ) |
180 | 186 |
|
181 | 187 | with TemporaryDirectory() as tmpdir: |
182 | | - model_path = pathlib.Path(tmpdir) / f"iris.{model_format}" |
| 188 | + model_path = pathlib.Path(tmpdir) / _get_model_filename("iris", model_format) |
183 | 189 | xgb_model.save_model(model_path) |
184 | 190 | tl_model = treelite.frontend.load_xgboost_model( |
185 | 191 | model_path, format_choice=model_format |
@@ -262,10 +268,7 @@ def test_xgb_nonlinear_objective( |
262 | 268 | ) |
263 | 269 |
|
264 | 270 | objective_tag = objective.replace(":", "_") |
265 | | - if model_format in ["json", "ubjson"]: |
266 | | - model_name = f"nonlinear_{objective_tag}.{model_format}" |
267 | | - else: |
268 | | - model_name = f"nonlinear_{objective_tag}.deprecated" |
| 271 | + model_name = _get_model_filename(f"nonlinear_{objective_tag}", model_format) |
269 | 272 | with TemporaryDirectory() as tmpdir: |
270 | 273 | model_path = pathlib.Path(tmpdir) / model_name |
271 | 274 | xgb_model.save_model(model_path) |
@@ -458,7 +461,9 @@ def test_xgb_multi_target_binary_classifier( |
458 | 461 | tl_model = treelite.frontend.from_xgboost(bst) |
459 | 462 | else: |
460 | 463 | with TemporaryDirectory() as tmpdir: |
461 | | - model_path = pathlib.Path(tmpdir) / f"multi_target.{model_format}" |
| 464 | + model_path = pathlib.Path(tmpdir) / _get_model_filename( |
| 465 | + "multi_target", model_format |
| 466 | + ) |
462 | 467 | bst.save_model(model_path) |
463 | 468 | tl_model = treelite.frontend.load_xgboost_model( |
464 | 469 | model_path, format_choice=model_format |
@@ -533,7 +538,7 @@ def test_xgb_multi_target_regressor( |
533 | 538 | ) |
534 | 539 |
|
535 | 540 | with TemporaryDirectory() as tmpdir: |
536 | | - model_path = pathlib.Path(tmpdir) / f"model.{model_format}" |
| 541 | + model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format) |
537 | 542 | xgb_model.save_model(model_path) |
538 | 543 | tl_model = treelite.frontend.load_xgboost_model( |
539 | 544 | model_path, format_choice=model_format |
@@ -578,7 +583,7 @@ def test_xgb_detect_format( |
578 | 583 | expected_pred = xgb_model.predict(xgb.DMatrix(X)).reshape((X.shape[0], 1, -1)) |
579 | 584 |
|
580 | 585 | with TemporaryDirectory() as tmpdir: |
581 | | - model_path = pathlib.Path(tmpdir) / f"model.{model_format}" |
| 586 | + model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format) |
582 | 587 | xgb_model.save_model(model_path) |
583 | 588 | detected_format = treelite.frontend._detect_xgboost_format(model_path) |
584 | 589 | assert detected_format == model_format |
|
0 commit comments