Skip to content

Commit dbac88c

Browse files
authored
Require Enum from polars. (#12240)
1 parent 939a5ff commit dbac88c

3 files changed

Lines changed: 27 additions & 15 deletions

File tree

python-package/xgboost/data.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,18 @@ def _check_pyarrow_for_polars() -> None:
937937
raise ImportError("`pyarrow` is required for polars.")
938938

939939

940+
def _reject_polars_categorical(data: DataType) -> None:
941+
pl = import_polars()
942+
943+
for name, dtype in zip(data.columns, data.dtypes):
944+
if isinstance(dtype, pl.Categorical):
945+
raise ValueError(
946+
"XGBoost does not support `polars.Categorical` because its "
947+
"encoding can be sparse. Use `polars.Enum` instead. "
948+
f"Invalid column: {name}",
949+
)
950+
951+
940952
def _transform_polars_df(
941953
data: DataType,
942954
enable_categorical: bool,
@@ -954,6 +966,7 @@ def _transform_polars_df(
954966
df = data
955967

956968
_check_pyarrow_for_polars()
969+
_reject_polars_categorical(df)
957970
table = df.to_arrow()
958971
return _transform_arrow_table(
959972
table, enable_categorical, feature_names, feature_types

src/common/hist_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2017-2024, XGBoost Contributors
2+
* Copyright 2017-2026, XGBoost Contributors
33
* \file hist_util.h
44
* \brief Utility for fast histogram aggregation
55
* \author Philip Cho, Tianqi Chen
@@ -8,7 +8,7 @@
88
#define XGBOOST_COMMON_HIST_UTIL_H_
99

1010
#include <algorithm>
11-
#include <cmath>
11+
#include <cmath> // for nextafter
1212
#include <cstdint> // for uint32_t
1313
#include <limits>
1414
#include <map>

tests/python/test_with_polars.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Copyright 2024, XGBoost contributors"""
1+
"""Copyright 2024-2026, XGBoost contributors"""
22

33
import json
44
from pathlib import Path
@@ -38,15 +38,16 @@ def test_polars_basic(
3838
if isinstance(Xy, xgb.QuantileDMatrix):
3939
# skip min values in the cut.
4040
np.testing.assert_allclose(res[1:, :], res1[1:, :])
41+
m = np.nextafter(np.float32(1), -np.inf) # min_val
4142
else:
4243
np.testing.assert_allclose(res, res1)
44+
m = np.float32(0)
4345

4446
# boolean
4547
df = pl.DataFrame({"a": [True, False, False], "b": [False, False, True]})
4648
Xy = DMatrixT(df)
47-
np.testing.assert_allclose(
48-
Xy.get_data().data, np.array([1, 0, 0, 0, 0, 1]), atol=1e-5
49-
)
49+
50+
np.testing.assert_equal(Xy.get_data().data, np.array([1, m, m, m, m, 1]))
5051

5152

5253
def test_polars_missing() -> None:
@@ -140,25 +141,23 @@ def test_regressor() -> None:
140141

141142

142143
def test_categorical() -> None:
143-
import polars as pl
144-
145144
cats = ["aa", "cc", "bb", "ee", "ee"]
146145
df = pl.DataFrame(
147146
{"f0": [1, 3, 2, 4, 4], "f1": cats},
148-
schema=[("f0", pl.Int64()), ("f1", pl.Categorical(ordering="lexical"))],
147+
schema=[("f0", pl.Int64()), ("f1", pl.Categorical())],
149148
)
150149

151-
data = xgb.DMatrix(df)
152-
categories = data.get_categories(export_to_arrow=True)
153-
assert dict(categories.to_arrow())["f0"] is None
154-
f1 = dict(categories.to_arrow())["f1"]
155-
assert f1 is not None
156-
assert f1.to_pylist() == cats[:4]
150+
with pytest.raises(ValueError, match="polars.Categorical.*polars.Enum"):
151+
xgb.DMatrix(df)
157152

158153
df = pl.DataFrame(
159154
{"f0": [1, 3, 2, 4, 4], "f1": cats},
160155
schema=[("f0", pl.Int64()), ("f1", pl.Enum(cats[:4]))],
161156
)
157+
arr = df["f1"].to_arrow()
158+
assert arr.dictionary.to_pylist() == cats[:4]
159+
assert arr.indices.to_pylist() == [0, 1, 2, 3, 3]
160+
162161
data = xgb.DMatrix(df)
163162
categories = data.get_categories(export_to_arrow=True)
164163
assert dict(categories.to_arrow())["f0"] is None

0 commit comments

Comments
 (0)