Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,18 @@ def _check_pyarrow_for_polars() -> None:
raise ImportError("`pyarrow` is required for polars.")


def _reject_polars_categorical(data: DataType) -> None:
pl = import_polars()

for name, dtype in zip(data.columns, data.dtypes):
if isinstance(dtype, pl.Categorical):
raise ValueError(
"XGBoost does not support `polars.Categorical` because its "
"encoding can be sparse. Use `polars.Enum` instead. "
f"Invalid column: {name}",
)


def _transform_polars_df(
data: DataType,
enable_categorical: bool,
Expand All @@ -954,6 +966,7 @@ def _transform_polars_df(
df = data

_check_pyarrow_for_polars()
_reject_polars_categorical(df)
table = df.to_arrow()
return _transform_arrow_table(
table, enable_categorical, feature_names, feature_types
Expand Down
4 changes: 2 additions & 2 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2024, XGBoost Contributors
* Copyright 2017-2026, XGBoost Contributors
* \file hist_util.h
* \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen
Expand All @@ -8,7 +8,7 @@
#define XGBOOST_COMMON_HIST_UTIL_H_

#include <algorithm>
#include <cmath>
#include <cmath> // for nextafter
#include <cstdint> // for uint32_t
#include <limits>
#include <map>
Expand Down
25 changes: 12 additions & 13 deletions tests/python/test_with_polars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Copyright 2024, XGBoost contributors"""
"""Copyright 2024-2026, XGBoost contributors"""

import json
from pathlib import Path
Expand Down Expand Up @@ -38,15 +38,16 @@ def test_polars_basic(
if isinstance(Xy, xgb.QuantileDMatrix):
# skip min values in the cut.
np.testing.assert_allclose(res[1:, :], res1[1:, :])
m = np.nextafter(np.float32(1), -np.inf) # min_val
else:
np.testing.assert_allclose(res, res1)
m = np.float32(0)

# boolean
df = pl.DataFrame({"a": [True, False, False], "b": [False, False, True]})
Xy = DMatrixT(df)
np.testing.assert_allclose(
Xy.get_data().data, np.array([1, 0, 0, 0, 0, 1]), atol=1e-5
)

np.testing.assert_equal(Xy.get_data().data, np.array([1, m, m, m, m, 1]))


def test_polars_missing() -> None:
Expand Down Expand Up @@ -140,25 +141,23 @@ def test_regressor() -> None:


def test_categorical() -> None:
import polars as pl

cats = ["aa", "cc", "bb", "ee", "ee"]
df = pl.DataFrame(
{"f0": [1, 3, 2, 4, 4], "f1": cats},
schema=[("f0", pl.Int64()), ("f1", pl.Categorical(ordering="lexical"))],
schema=[("f0", pl.Int64()), ("f1", pl.Categorical())],
)

data = xgb.DMatrix(df)
categories = data.get_categories(export_to_arrow=True)
assert dict(categories.to_arrow())["f0"] is None
f1 = dict(categories.to_arrow())["f1"]
assert f1 is not None
assert f1.to_pylist() == cats[:4]
with pytest.raises(ValueError, match="polars.Categorical.*polars.Enum"):
xgb.DMatrix(df)

df = pl.DataFrame(
{"f0": [1, 3, 2, 4, 4], "f1": cats},
schema=[("f0", pl.Int64()), ("f1", pl.Enum(cats[:4]))],
)
arr = df["f1"].to_arrow()
assert arr.dictionary.to_pylist() == cats[:4]
assert arr.indices.to_pylist() == [0, 1, 2, 3, 3]

data = xgb.DMatrix(df)
categories = data.get_categories(export_to_arrow=True)
assert dict(categories.to_arrow())["f0"] is None
Expand Down
Loading