Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 1ee4a18

Browse files
authored
Initialize DMatrix in Rabit context (#179)
1 parent c723355 commit 1ee4a18

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

xgboost_ray/main.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,8 @@ def load_data(self, data: RayDMatrix):
543543
self._local_n[data] = sum(len(a) for a in param["data"])
544544
else:
545545
self._local_n[data] = len(param["data"])
546-
data.unload_data() # Free object store
547546

548-
matrix = _get_dmatrix(data, param)
549-
self._data[data] = matrix
547+
self._data[data] = param
550548

551549
self._distributed_callbacks.after_data_loading(self, data)
552550

@@ -578,19 +576,9 @@ def train(self, rabit_args: List[str], return_bst: bool,
578576
if dtrain not in self._data:
579577
self.load_data(dtrain)
580578

581-
local_dtrain = self._data[dtrain]
582-
583-
if not local_dtrain.get_label().size:
584-
raise RuntimeError(
585-
"Training data has no label set. Please make sure to set "
586-
"the `label` argument when initializing `RayDMatrix()` "
587-
"for data you would like to train on.")
588-
589-
local_evals = []
590-
for deval, name in evals:
579+
for deval, _name in evals:
591580
if deval not in self._data:
592581
self.load_data(deval)
593-
local_evals.append((self._data[deval], name))
594582

595583
evals_result = dict()
596584

@@ -609,6 +597,21 @@ def train(self, rabit_args: List[str], return_bst: bool,
609597
def _train():
610598
try:
611599
with _RabitContext(str(id(self)), rabit_args):
600+
601+
local_dtrain = _get_dmatrix(dtrain, self._data[dtrain])
602+
603+
if not local_dtrain.get_label().size:
604+
raise RuntimeError(
605+
"Training data has no label set. Please make sure "
606+
"to set the `label` argument when initializing "
607+
"`RayDMatrix()` for data you would like "
608+
"to train on.")
609+
610+
local_evals = []
611+
for deval, name in evals:
612+
local_evals.append((_get_dmatrix(
613+
deval, self._data[deval]), name))
614+
612615
if LEGACY_CALLBACK:
613616
for xgb_callback in kwargs.get("callbacks", []):
614617
if isinstance(xgb_callback, TrainingCallback):
@@ -668,7 +671,7 @@ def predict(self, model: xgb.Booster, data: RayDMatrix, **kwargs):
668671

669672
if data not in self._data:
670673
self.load_data(data)
671-
local_data = self._data[data]
674+
local_data = _get_dmatrix(data, self._data[data])
672675

673676
predictions = model.predict(local_data, **kwargs)
674677
if predictions.ndim == 1:

0 commit comments

Comments
 (0)