Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 147d15c

Browse files
authored
Fix RayDeviceQuantileDMatrix (#69)
This fixes the RayDeviceQuantileDMatrix usage for the new data source interface. Also gets rid of Ray core calls to private APIs.
1 parent 61d293e commit 147d15c

File tree

2 files changed

+67
-18
lines changed

2 files changed

+67
-18
lines changed

xgboost_ray/main.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,19 +209,29 @@ def _set_omp_num_threads():
209209

210210
def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
211211
if isinstance(data, RayDeviceQuantileDMatrix):
212-
if isinstance(param["data"], list):
213-
dm_param = {
214-
"feature_names": data.feature_names,
215-
"feature_types": data.feature_types,
216-
"missing": data.missing,
217-
}
218-
if not isinstance(data, xgb.DeviceQuantileDMatrix):
219-
pass
220-
param.update(dm_param)
221-
it = RayDataIter(**param)
222-
matrix = xgb.DeviceQuantileDMatrix(it, **dm_param)
223-
else:
224-
matrix = xgb.DeviceQuantileDMatrix(**param)
212+
# If we only got a single data shard, create a list so we can
213+
# iterate over it
214+
if not isinstance(param["data"], list):
215+
param["data"] = [param["data"]]
216+
217+
if not isinstance(param["label"], list):
218+
param["label"] = [param["label"]]
219+
if not isinstance(param["weight"], list):
220+
param["weight"] = [param["weight"]]
221+
if not isinstance(param["data"], list):
222+
param["base_margin"] = [param["base_margin"]]
223+
224+
param["label_lower_bound"] = [None]
225+
param["label_upper_bound"] = [None]
226+
227+
dm_param = {
228+
"feature_names": data.feature_names,
229+
"feature_types": data.feature_types,
230+
"missing": data.missing,
231+
}
232+
param.update(dm_param)
233+
it = RayDataIter(**param)
234+
matrix = xgb.DeviceQuantileDMatrix(it, **dm_param)
225235
else:
226236
if isinstance(param["data"], list):
227237
dm_param = {
@@ -436,7 +446,8 @@ def train(self, rabit_args: List[str], params: Dict[str, Any],
436446
if num_threads > 0:
437447
local_params["num_threads"] = num_threads
438448
else:
439-
local_params["nthread"] = ray.utils.get_num_cpus()
449+
local_params["nthread"] = sum(
450+
num for _, num in ray.get_resource_ids().get("CPU", []))
440451

441452
if dtrain not in self._data:
442453
self.load_data(dtrain)

xgboost_ray/matrix.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,9 @@ class RayDMatrix:
578578
def __init__(self,
579579
data: Data,
580580
label: Optional[Data] = None,
581-
missing: Optional[float] = None,
582581
weight: Optional[Data] = None,
583582
base_margin: Optional[Data] = None,
583+
missing: Optional[float] = None,
584584
label_lower_bound: Optional[Data] = None,
585585
label_upper_bound: Optional[Data] = None,
586586
feature_names: Optional[List[str]] = None,
@@ -730,12 +730,50 @@ def __eq__(self, other):
730730
class RayDeviceQuantileDMatrix(RayDMatrix):
731731
"""Currently just a thin wrapper for type detection"""
732732

733-
def __init__(self, *args, **kwargs):
733+
def __init__(self,
734+
data: Data,
735+
label: Optional[Data] = None,
736+
weight: Optional[Data] = None,
737+
base_margin: Optional[Data] = None,
738+
missing: Optional[float] = None,
739+
label_lower_bound: Optional[Data] = None,
740+
label_upper_bound: Optional[Data] = None,
741+
feature_names: Optional[List[str]] = None,
742+
feature_types: Optional[List[np.dtype]] = None,
743+
*args,
744+
**kwargs):
734745
if cp is None:
735746
raise RuntimeError(
736747
"RayDeviceQuantileDMatrix requires cupy to be installed."
737-
"\nFIX THIS by installing cupy: `pip install cupy`")
738-
super(RayDeviceQuantileDMatrix, self).__init__(*args, **kwargs)
748+
"\nFIX THIS by installing cupy: `pip install cupy-cudaXYZ` "
749+
"where XYZ is your local CUDA version.")
750+
if label_lower_bound or label_upper_bound:
751+
raise RuntimeError(
752+
"RayDeviceQuantileDMatrix does not support "
753+
"`label_lower_bound` and `label_upper_bound` (just as the "
754+
"xgboost.DeviceQuantileDMatrix). Please pass None instead.")
755+
super(RayDeviceQuantileDMatrix, self).__init__(
756+
data=data,
757+
label=label,
758+
weight=weight,
759+
base_margin=base_margin,
760+
missing=missing,
761+
label_lower_bound=None,
762+
label_upper_bound=None,
763+
feature_names=feature_names,
764+
feature_types=feature_types,
765+
*args,
766+
**kwargs)
767+
768+
def get_data(
769+
self, rank: int, num_actors: Optional[int] = None
770+
) -> Dict[str, Union[None, pd.DataFrame, List[Optional[pd.DataFrame]]]]:
771+
data_dict = super(RayDeviceQuantileDMatrix, self).get_data(
772+
rank=rank, num_actors=num_actors)
773+
# Remove some dict keys here that are generated automatically
774+
data_dict.pop("label_lower_bound", None)
775+
data_dict.pop("label_upper_bound", None)
776+
return data_dict
739777

740778

741779
def _can_load_distributed(source: Data) -> bool:

0 commit comments

Comments
 (0)