Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 536b702

Browse files
authored
Add enable_categorical, better detection for being in a PG (#235)
XGBoost>=1.5 added a new DMatrix parameter, `enable_categorical`; this PR adds support for it here. This PR also broadens the detection of already running inside a placement group, so that it is no longer limited to Tune. This allows xgboost-ray to be run in nested tasks that are not related to Tune (this should really be two PRs, but I needed one branch to patch it for a hackathon). Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
1 parent 705a592 commit 536b702

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

xgboost_ray/main.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
315315
"feature_types": data.feature_types,
316316
"missing": data.missing,
317317
}
318+
319+
if data.enable_categorical is not None:
320+
dm_param["enable_categorical"] = data.enable_categorical
321+
318322
param.update(dm_param)
319323
it = RayDataIter(**param)
320324
matrix = xgb.DeviceQuantileDMatrix(it, **dm_param)
@@ -342,6 +346,9 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
342346
if "qid" not in inspect.signature(xgb.DMatrix).parameters:
343347
param.pop("qid", None)
344348

349+
if data.enable_categorical is not None:
350+
param["enable_categorical"] = data.enable_categorical
351+
345352
matrix = xgb.DMatrix(**param)
346353

347354
if not LEGACY_MATRIX:
@@ -855,25 +862,22 @@ def _create_placement_group(cpus_per_actor, gpus_per_actor,
855862

856863

857864
def _create_communication_processes(added_tune_callback: bool = False):
858-
# Create Queue and Event actors and make sure to colocate with driver node.
859-
node_ip = get_node_ip_address()
860865
# Have to explicitly set num_cpus to 0.
861866
placement_option = {"num_cpus": 0}
862-
if added_tune_callback:
863-
# If Tune is using placement groups, then we force Queue and
867+
current_pg = get_current_placement_group()
868+
if current_pg is not None:
869+
# If we are already in a placement group, let's use it
870+
# Also, if we are specifically in Tune, let's
871+
# ensure that we force Queue and
864872
# StopEvent onto same bundle as the Trainable.
865-
# This forces all 3 to be on the same node.
866-
current_pg = get_current_placement_group()
867-
if current_pg is None:
868-
# This means the user is not using Tune PGs after all -
869-
# e.g. via setting an environment variable.
870-
placement_option.update({"resources": {f"node:{node_ip}": 0.01}})
871-
else:
872-
placement_option.update({
873-
"placement_group": current_pg,
874-
"placement_group_bundle_index": 0
875-
})
873+
placement_option.update({
874+
"placement_group": current_pg,
875+
"placement_group_bundle_index": 0 if added_tune_callback else -1
876+
})
876877
else:
878+
# Create Queue and Event actors and make sure to colocate with
879+
# driver node.
880+
node_ip = get_node_ip_address()
877881
placement_option.update({"resources": {f"node:{node_ip}": 0.01}})
878882
queue = Queue(actor_options=placement_option) # Queue actor
879883
stop_event = Event(actor_options=placement_option) # Stop event actor
@@ -1327,7 +1331,7 @@ def _wrapped(*args, **kwargs):
13271331
"Please disable elastic_training in RayParams in "
13281332
"order to use xgboost_ray with Tune.")
13291333

1330-
if added_tune_callback:
1334+
if added_tune_callback or get_current_placement_group():
13311335
# Don't autodetect resources when used with Tune.
13321336
cpus_per_actor = ray_params.cpus_per_actor
13331337
gpus_per_actor = max(0, ray_params.gpus_per_actor)
@@ -1408,7 +1412,7 @@ def _wrapped(*args, **kwargs):
14081412

14091413
placement_strategy = None
14101414
if not ray_params.elastic_training:
1411-
if added_tune_callback:
1415+
if added_tune_callback or get_current_placement_group():
14121416
# Tune is using placement groups, so the strategy has already
14131417
# been set. Don't create an additional placement_group here.
14141418
placement_strategy = None

xgboost_ray/matrix.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
label_upper_bound: List[Optional[Data]],
8484
feature_names: Optional[List[str]],
8585
feature_types: Optional[List[np.dtype]],
86+
enable_categorical: Optional[bool],
8687
):
8788
super(RayDataIter, self).__init__()
8889

@@ -98,6 +99,7 @@ def __init__(
9899
self._label_upper_bound = label_upper_bound
99100
self._feature_names = feature_names
100101
self._feature_types = feature_types
102+
self.enable_categorical = enable_categorical
101103

102104
self._iter = 0
103105

@@ -129,7 +131,8 @@ def next(self, input_data: Callable):
129131
label_lower_bound=self._prop(self._label_lower_bound),
130132
label_upper_bound=self._prop(self._label_upper_bound),
131133
feature_names=self._feature_names,
132-
feature_types=self._feature_types)
134+
feature_types=self._feature_types,
135+
enable_categorical=self._enable_categorical)
133136
self._iter += 1
134137
return 1
135138

@@ -146,6 +149,7 @@ def __init__(self,
146149
feature_names: Optional[List[str]] = None,
147150
feature_types: Optional[List[np.dtype]] = None,
148151
qid: Optional[Data] = None,
152+
enable_categorical: Optional[bool] = None,
149153
filetype: Optional[RayFileType] = None,
150154
ignore: Optional[List[str]] = None,
151155
**kwargs):
@@ -159,6 +163,7 @@ def __init__(self,
159163
self.feature_names = feature_names
160164
self.feature_types = feature_types
161165
self.qid = qid
166+
self.enable_categorical = enable_categorical
162167

163168
self.data_source = None
164169
self.actor_shards = None
@@ -655,6 +660,7 @@ def __init__(self,
655660
feature_names: Optional[List[str]] = None,
656661
feature_types: Optional[List[np.dtype]] = None,
657662
qid: Optional[Data] = None,
663+
enable_categorical: Optional[bool] = None,
658664
num_actors: Optional[int] = None,
659665
filetype: Optional[RayFileType] = None,
660666
ignore: Optional[List[str]] = None,
@@ -677,6 +683,7 @@ def __init__(self,
677683
self.feature_names = feature_names
678684
self.feature_types = feature_types
679685
self.qid = qid
686+
self.enable_categorical = enable_categorical
680687
self.missing = missing
681688

682689
self.num_actors = num_actors
@@ -706,6 +713,7 @@ def __init__(self,
706713
label_upper_bound=label_upper_bound,
707714
feature_names=feature_names,
708715
feature_types=feature_types,
716+
enable_categorical=enable_categorical,
709717
filetype=filetype,
710718
ignore=ignore,
711719
qid=qid,
@@ -721,6 +729,7 @@ def __init__(self,
721729
label_upper_bound=label_upper_bound,
722730
feature_names=feature_names,
723731
feature_types=feature_types,
732+
enable_categorical=enable_categorical,
724733
filetype=filetype,
725734
ignore=ignore,
726735
qid=qid,
@@ -829,6 +838,7 @@ def __init__(self,
829838
feature_names: Optional[List[str]] = None,
830839
feature_types: Optional[List[np.dtype]] = None,
831840
qid: Optional[Data] = None,
841+
enable_categorical: Optional[bool] = None,
832842
*args,
833843
**kwargs):
834844
if cp is None:
@@ -852,6 +862,7 @@ def __init__(self,
852862
feature_names=feature_names,
853863
feature_types=feature_types,
854864
qid=qid,
865+
enable_categorical=enable_categorical,
855866
*args,
856867
**kwargs)
857868

0 commit comments

Comments
 (0)