Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit e7867d9

Browse files
authored
Check if label is set for training/evaluation data (#64)
1 parent 5d65f24 commit e7867d9

File tree

2 files changed

+39
-12
lines changed

2 files changed

+39
-12
lines changed

xgboost_ray/main.py

Lines changed: 32 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -439,6 +439,12 @@ def train(self, rabit_args: List[str], params: Dict[str, Any],
439439

440440
local_dtrain = self._data[dtrain]
441441

442+
if not local_dtrain.get_label().size:
443+
raise RuntimeError(
444+
"Training data has no label set. Please make sure to set "
445+
"the `label` argument when initializing `RayDMatrix()` "
446+
"for data you would like to train on.")
447+
442448
local_evals = []
443449
for deval, name in evals:
444450
if deval not in self._data:
@@ -926,16 +932,18 @@ def handle_actor_failure(actor_id):
926932
return bst, evals_result, _training_state.additional_results
927933

928934

929-
def train(params: Dict,
930-
dtrain: RayDMatrix,
931-
num_boost_round: int = 10,
932-
*args,
933-
evals=(),
934-
evals_result: Optional[Dict] = None,
935-
additional_results: Optional[Dict] = None,
936-
ray_params: Union[None, RayParams, Dict] = None,
937-
_remote: Optional[bool] = None,
938-
**kwargs) -> xgb.Booster:
935+
def train(
936+
params: Dict,
937+
dtrain: RayDMatrix,
938+
num_boost_round: int = 10,
939+
*args,
940+
evals: Union[List[Tuple[RayDMatrix, str]], Tuple[RayDMatrix, str]] = (
941+
),
942+
evals_result: Optional[Dict] = None,
943+
additional_results: Optional[Dict] = None,
944+
ray_params: Union[None, RayParams, Dict] = None,
945+
_remote: Optional[bool] = None,
946+
**kwargs) -> xgb.Booster:
939947
"""Distributed XGBoost training via Ray.
940948
941949
This function will connect to a Ray cluster, create ``num_actors``
@@ -970,8 +978,8 @@ def train(params: Dict,
970978
Args:
971979
params (Dict): parameter dict passed to ``xgboost.train()``
972980
dtrain (RayDMatrix): Data object containing the training data.
973-
evals (Union[List[Tuple], Tuple]): ``evals`` tuple passed to
974-
``xgboost.train()``.
981+
evals (Union[List[Tuple[RayDMatrix, str]], Tuple[RayDMatrix, str]]):
982+
``evals`` tuple passed to ``xgboost.train()``.
975983
evals_result (Optional[Dict]): Dict to store evaluation results in.
976984
additional_results (Optional[Dict]): Dict to store additional results.
977985
ray_params (Union[None, RayParams, Dict]): Parameters to configure
@@ -1074,9 +1082,21 @@ def _wrapped(*args, **kwargs):
10741082
"effectively disabled. Please set `RayParams.max_actor_restarts` "
10751083
"to something larger than 0 to enable elastic training.")
10761084

1085+
if not dtrain.has_label:
1086+
raise ValueError(
1087+
"Training data has no label set. Please make sure to set "
1088+
"the `label` argument when initializing `RayDMatrix()` "
1089+
"for data you would like to train on.")
1090+
10771091
if not dtrain.loaded and not dtrain.distributed:
10781092
dtrain.load_data(ray_params.num_actors)
1093+
10791094
for (deval, name) in evals:
1095+
if not deval.has_label:
1096+
raise ValueError(
1097+
"Evaluation data has no label set. Please make sure to set "
1098+
"the `label` argument when initializing `RayDMatrix()` "
1099+
"for data you would like to evaluate on.")
10801100
if not deval.loaded and not deval.distributed:
10811101
deval.load_data(ray_params.num_actors)
10821102

xgboost_ray/matrix.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,9 @@ def _load_data_petastorm(self, data: Sequence[str]):
348348

349349
local_df = pd.concat(shards, copy=False)
350350

351+
if self.ignore:
352+
local_df = local_df[local_df.columns.difference(self.ignore)]
353+
351354
x, y, w, b, ll, lu = self._split_dataframe(local_df)
352355
return x, y, w, b, ll, lu
353356

@@ -706,6 +709,10 @@ def __init__(self,
706709
if not distributed and num_actors is not None and not lazy:
707710
self.load_data(num_actors)
708711

712+
@property
713+
def has_label(self):
714+
return self.loader.label is not None
715+
709716
def load_data(self,
710717
num_actors: Optional[int] = None,
711718
rank: Optional[int] = None):

0 commit comments

Comments
 (0)