|
1 | | -"""Copyright 2024, XGBoost contributors""" |
| 1 | +"""Copyright 2024-2026, XGBoost contributors""" |
2 | 2 |
|
3 | 3 | import json |
4 | 4 | from pathlib import Path |
@@ -38,15 +38,16 @@ def test_polars_basic( |
38 | 38 | if isinstance(Xy, xgb.QuantileDMatrix): |
39 | 39 | # skip min values in the cut. |
40 | 40 | np.testing.assert_allclose(res[1:, :], res1[1:, :]) |
| 41 | + m = np.nextafter(np.float32(1), -np.inf) # min_val |
41 | 42 | else: |
42 | 43 | np.testing.assert_allclose(res, res1) |
| 44 | + m = np.float32(0) |
43 | 45 |
|
44 | 46 | # boolean |
45 | 47 | df = pl.DataFrame({"a": [True, False, False], "b": [False, False, True]}) |
46 | 48 | Xy = DMatrixT(df) |
47 | | - np.testing.assert_allclose( |
48 | | - Xy.get_data().data, np.array([1, 0, 0, 0, 0, 1]), atol=1e-5 |
49 | | - ) |
| 49 | + |
| 50 | + np.testing.assert_equal(Xy.get_data().data, np.array([1, m, m, m, m, 1])) |
50 | 51 |
|
51 | 52 |
|
52 | 53 | def test_polars_missing() -> None: |
@@ -140,25 +141,23 @@ def test_regressor() -> None: |
140 | 141 |
|
141 | 142 |
|
142 | 143 | def test_categorical() -> None: |
143 | | - import polars as pl |
144 | | - |
145 | 144 | cats = ["aa", "cc", "bb", "ee", "ee"] |
146 | 145 | df = pl.DataFrame( |
147 | 146 | {"f0": [1, 3, 2, 4, 4], "f1": cats}, |
148 | | - schema=[("f0", pl.Int64()), ("f1", pl.Categorical(ordering="lexical"))], |
| 147 | + schema=[("f0", pl.Int64()), ("f1", pl.Categorical())], |
149 | 148 | ) |
150 | 149 |
|
151 | | - data = xgb.DMatrix(df) |
152 | | - categories = data.get_categories(export_to_arrow=True) |
153 | | - assert dict(categories.to_arrow())["f0"] is None |
154 | | - f1 = dict(categories.to_arrow())["f1"] |
155 | | - assert f1 is not None |
156 | | - assert f1.to_pylist() == cats[:4] |
| 150 | + with pytest.raises(ValueError, match="polars.Categorical.*polars.Enum"): |
| 151 | + xgb.DMatrix(df) |
157 | 152 |
|
158 | 153 | df = pl.DataFrame( |
159 | 154 | {"f0": [1, 3, 2, 4, 4], "f1": cats}, |
160 | 155 | schema=[("f0", pl.Int64()), ("f1", pl.Enum(cats[:4]))], |
161 | 156 | ) |
| 157 | + arr = df["f1"].to_arrow() |
| 158 | + assert arr.dictionary.to_pylist() == cats[:4] |
| 159 | + assert arr.indices.to_pylist() == [0, 1, 2, 3, 3] |
| 160 | + |
162 | 161 | data = xgb.DMatrix(df) |
163 | 162 | categories = data.get_categories(export_to_arrow=True) |
164 | 163 | assert dict(categories.to_arrow())["f0"] is None |
|
0 commit comments