@@ -75,7 +75,7 @@ def inner_f(*args, **kwargs):
7575
7676from xgboost_ray .matrix import RayDMatrix , combine_data , \
7777 RayDeviceQuantileDMatrix , RayDataIter , concat_dataframes , \
78- LEGACY_MATRIX
78+ LEGACY_MATRIX , QUANTILE_AVAILABLE , RayQuantileDMatrix
7979from xgboost_ray .session import init_session , put_queue , \
8080 set_session_queue , get_rabit_rank
8181
@@ -320,7 +320,28 @@ def _set_omp_num_threads():
320320 return int (float (os .environ .get ("OMP_NUM_THREADS" , "0.0" )))
321321
322322
323+ def _prepare_dmatrix_params (param : Dict ) -> Dict :
324+ dm_param = {
325+ "data" : concat_dataframes (param ["data" ]),
326+ "label" : concat_dataframes (param ["label" ]),
327+ "weight" : concat_dataframes (param ["weight" ]),
328+ "feature_weights" : concat_dataframes (param ["feature_weights" ]),
329+ "qid" : concat_dataframes (param ["qid" ]),
330+ "base_margin" : concat_dataframes (param ["base_margin" ]),
331+ "label_lower_bound" : concat_dataframes (param ["label_lower_bound" ]),
332+ "label_upper_bound" : concat_dataframes (param ["label_upper_bound" ]),
333+ }
334+ return dm_param
335+
336+
323337def _get_dmatrix (data : RayDMatrix , param : Dict ) -> xgb .DMatrix :
338+ if QUANTILE_AVAILABLE and isinstance (data , RayQuantileDMatrix ):
339+ if isinstance (param ["data" ], list ):
340+ qdm_param = _prepare_dmatrix_params (param )
341+ param .update (qdm_param )
342+ if data .enable_categorical is not None :
343+ param ["enable_categorical" ] = data .enable_categorical
344+ matrix = xgb .QuantileDMatrix (** param )
324345 if not LEGACY_MATRIX and isinstance (data , RayDeviceQuantileDMatrix ):
325346 # If we only got a single data shard, create a list so we can
326347 # iterate over it
@@ -355,18 +376,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
355376 matrix = xgb .DeviceQuantileDMatrix (it , ** dm_param )
356377 else :
357378 if isinstance (param ["data" ], list ):
358- dm_param = {
359- "data" : concat_dataframes (param ["data" ]),
360- "label" : concat_dataframes (param ["label" ]),
361- "weight" : concat_dataframes (param ["weight" ]),
362- "feature_weights" : concat_dataframes (param ["feature_weights" ]),
363- "qid" : concat_dataframes (param ["qid" ]),
364- "base_margin" : concat_dataframes (param ["base_margin" ]),
365- "label_lower_bound" : concat_dataframes (
366- param ["label_lower_bound" ]),
367- "label_upper_bound" : concat_dataframes (
368- param ["label_upper_bound" ]),
369- }
379+ dm_param = _prepare_dmatrix_params (param )
370380 param .update (dm_param )
371381
372382 ll = param .pop ("label_lower_bound" , None )
@@ -669,7 +679,6 @@ def _train():
669679 for deval , name in evals :
670680 local_evals .append ((_get_dmatrix (
671681 deval , self ._data [deval ]), name ))
672-
673682 if LEGACY_CALLBACK :
674683 for xgb_callback in kwargs .get ("callbacks" , []):
675684 if isinstance (xgb_callback , TrainingCallback ):
0 commit comments