Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 648d6dd

Browse files
author
Zhi Lin
authored
Add QuantileDMatrix support (#279)
* add quantile matrix Signed-off-by: Zhi Lin <zhi.lin@intel.com> * revert unrelated changes Signed-off-by: Zhi Lin <zhi.lin@intel.com> * add helper function Signed-off-by: Zhi Lin <zhi.lin@intel.com> * format Signed-off-by: Zhi Lin <zhi.lin@intel.com> --------- Signed-off-by: Zhi Lin <zhi.lin@intel.com>
1 parent 3a3123e commit 648d6dd

File tree

2 files changed

+35
-14
lines changed

2 files changed

+35
-14
lines changed

xgboost_ray/main.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def inner_f(*args, **kwargs):
7575

7676
from xgboost_ray.matrix import RayDMatrix, combine_data, \
7777
RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes, \
78-
LEGACY_MATRIX
78+
LEGACY_MATRIX, QUANTILE_AVAILABLE, RayQuantileDMatrix
7979
from xgboost_ray.session import init_session, put_queue, \
8080
set_session_queue, get_rabit_rank
8181

@@ -320,7 +320,28 @@ def _set_omp_num_threads():
320320
return int(float(os.environ.get("OMP_NUM_THREADS", "0.0")))
321321

322322

323+
def _prepare_dmatrix_params(param: Dict) -> Dict:
    """Return a new dict of matrix fields, each run through concat_dataframes.

    Every field named below is looked up in ``param`` and handed to
    ``concat_dataframes``; the results are collected under the same keys.
    ``param`` itself is left untouched, and a missing key raises ``KeyError``.

    NOTE(review): the values are presumably lists of data shards -- the
    visible callers only invoke this when ``param["data"]`` is a list.
    """
    field_names = (
        "data",
        "label",
        "weight",
        "feature_weights",
        "qid",
        "base_margin",
        "label_lower_bound",
        "label_upper_bound",
    )
    return {name: concat_dataframes(param[name]) for name in field_names}
335+
336+
323337
def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
338+
if QUANTILE_AVAILABLE and isinstance(data, RayQuantileDMatrix):
339+
if isinstance(param["data"], list):
340+
qdm_param = _prepare_dmatrix_params(param)
341+
param.update(qdm_param)
342+
if data.enable_categorical is not None:
343+
param["enable_categorical"] = data.enable_categorical
344+
matrix = xgb.QuantileDMatrix(**param)
324345
if not LEGACY_MATRIX and isinstance(data, RayDeviceQuantileDMatrix):
325346
# If we only got a single data shard, create a list so we can
326347
# iterate over it
@@ -355,18 +376,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
355376
matrix = xgb.DeviceQuantileDMatrix(it, **dm_param)
356377
else:
357378
if isinstance(param["data"], list):
358-
dm_param = {
359-
"data": concat_dataframes(param["data"]),
360-
"label": concat_dataframes(param["label"]),
361-
"weight": concat_dataframes(param["weight"]),
362-
"feature_weights": concat_dataframes(param["feature_weights"]),
363-
"qid": concat_dataframes(param["qid"]),
364-
"base_margin": concat_dataframes(param["base_margin"]),
365-
"label_lower_bound": concat_dataframes(
366-
param["label_lower_bound"]),
367-
"label_upper_bound": concat_dataframes(
368-
param["label_upper_bound"]),
369-
}
379+
dm_param = _prepare_dmatrix_params(param)
370380
param.update(dm_param)
371381

372382
ll = param.pop("label_lower_bound", None)
@@ -669,7 +679,6 @@ def _train():
669679
for deval, name in evals:
670680
local_evals.append((_get_dmatrix(
671681
deval, self._data[deval]), name))
672-
673682
if LEGACY_CALLBACK:
674683
for xgb_callback in kwargs.get("callbacks", []):
675684
if isinstance(xgb_callback, TrainingCallback):

xgboost_ray/matrix.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ class RayDataset:
3737
DataIter = object
3838
LEGACY_MATRIX = True
3939

40+
# Probe whether the installed xgboost provides QuantileDMatrix, so callers can
# branch on QUANTILE_AVAILABLE instead of failing at import time.
try:
    # The class is spelled ``QuantileDMatrix`` (capital "M"). The previous
    # spelling ``QuantileDmatrix`` never exists in xgboost, so the import
    # always failed and QUANTILE_AVAILABLE was always False.
    from xgboost.core import QuantileDMatrix
    QUANTILE_AVAILABLE = True
except ImportError:
    # Placeholder so the name still resolves when the class is unavailable.
    QuantileDMatrix = object
    QUANTILE_AVAILABLE = False

# Backward-compatible alias for the previously misspelled module-level name.
QuantileDmatrix = QuantileDMatrix
46+
4047
if TYPE_CHECKING:
4148
from xgboost_ray.xgb import xgboost as xgb
4249

@@ -875,6 +882,11 @@ def __eq__(self, other):
875882
return self.__hash__() == other.__hash__()
876883

877884

885+
class RayQuantileDMatrix(RayDMatrix):
    """Marker subclass of ``RayDMatrix`` that adds no behavior of its own.

    It exists so that code can recognise quantile matrices with an
    ``isinstance`` check.
    """
888+
889+
878890
class RayDeviceQuantileDMatrix(RayDMatrix):
879891
"""Currently just a thin wrapper for type detection"""
880892

0 commit comments

Comments
 (0)