@@ -543,10 +543,8 @@ def load_data(self, data: RayDMatrix):
543543 self ._local_n [data ] = sum (len (a ) for a in param ["data" ])
544544 else :
545545 self ._local_n [data ] = len (param ["data" ])
546- data .unload_data () # Free object store
547546
548- matrix = _get_dmatrix (data , param )
549- self ._data [data ] = matrix
547+ self ._data [data ] = param
550548
551549 self ._distributed_callbacks .after_data_loading (self , data )
552550
@@ -578,19 +576,9 @@ def train(self, rabit_args: List[str], return_bst: bool,
578576 if dtrain not in self ._data :
579577 self .load_data (dtrain )
580578
581- local_dtrain = self ._data [dtrain ]
582-
583- if not local_dtrain .get_label ().size :
584- raise RuntimeError (
585- "Training data has no label set. Please make sure to set "
586- "the `label` argument when initializing `RayDMatrix()` "
587- "for data you would like to train on." )
588-
589- local_evals = []
590- for deval , name in evals :
579+ for deval , _name in evals :
591580 if deval not in self ._data :
592581 self .load_data (deval )
593- local_evals .append ((self ._data [deval ], name ))
594582
595583 evals_result = dict ()
596584
@@ -609,6 +597,21 @@ def train(self, rabit_args: List[str], return_bst: bool,
609597 def _train ():
610598 try :
611599 with _RabitContext (str (id (self )), rabit_args ):
600+
601+ local_dtrain = _get_dmatrix (dtrain , self ._data [dtrain ])
602+
603+ if not local_dtrain .get_label ().size :
604+ raise RuntimeError (
605+ "Training data has no label set. Please make sure "
606+ "to set the `label` argument when initializing "
607+ "`RayDMatrix()` for data you would like "
608+ "to train on." )
609+
610+ local_evals = []
611+ for deval , name in evals :
612+ local_evals .append ((_get_dmatrix (
613+ deval , self ._data [deval ]), name ))
614+
612615 if LEGACY_CALLBACK :
613616 for xgb_callback in kwargs .get ("callbacks" , []):
614617 if isinstance (xgb_callback , TrainingCallback ):
@@ -668,7 +671,7 @@ def predict(self, model: xgb.Booster, data: RayDMatrix, **kwargs):
668671
669672 if data not in self ._data :
670673 self .load_data (data )
671- local_data = self ._data [data ]
674+ local_data = _get_dmatrix ( data , self ._data [data ])
672675
673676 predictions = model .predict (local_data , ** kwargs )
674677 if predictions .ndim == 1 :
0 commit comments