@@ -578,9 +578,9 @@ class RayDMatrix:
578578 def __init__ (self ,
579579 data : Data ,
580580 label : Optional [Data ] = None ,
581- missing : Optional [float ] = None ,
582581 weight : Optional [Data ] = None ,
583582 base_margin : Optional [Data ] = None ,
583+ missing : Optional [float ] = None ,
584584 label_lower_bound : Optional [Data ] = None ,
585585 label_upper_bound : Optional [Data ] = None ,
586586 feature_names : Optional [List [str ]] = None ,
@@ -730,12 +730,50 @@ def __eq__(self, other):
730730class RayDeviceQuantileDMatrix (RayDMatrix ):
731731 """Currently just a thin wrapper for type detection"""
732732
733- def __init__ (self , * args , ** kwargs ):
733+ def __init__ (self ,
734+ data : Data ,
735+ label : Optional [Data ] = None ,
736+ weight : Optional [Data ] = None ,
737+ base_margin : Optional [Data ] = None ,
738+ missing : Optional [float ] = None ,
739+ label_lower_bound : Optional [Data ] = None ,
740+ label_upper_bound : Optional [Data ] = None ,
741+ feature_names : Optional [List [str ]] = None ,
742+ feature_types : Optional [List [np .dtype ]] = None ,
743+ * args ,
744+ ** kwargs ):
734745 if cp is None :
735746 raise RuntimeError (
736747 "RayDeviceQuantileDMatrix requires cupy to be installed."
737- "\n FIX THIS by installing cupy: `pip install cupy`" )
738- super (RayDeviceQuantileDMatrix , self ).__init__ (* args , ** kwargs )
748+ "\n FIX THIS by installing cupy: `pip install cupy-cudaXYZ` "
749+ "where XYZ is your local CUDA version." )
750+ if label_lower_bound or label_upper_bound :
751+ raise RuntimeError (
752+ "RayDeviceQuantileDMatrix does not support "
753+ "`label_lower_bound` and `label_upper_bound` (just as the "
754+ "xgboost.DeviceQuantileDMatrix). Please pass None instead." )
755+ super (RayDeviceQuantileDMatrix , self ).__init__ (
756+ data = data ,
757+ label = label ,
758+ weight = weight ,
759+ base_margin = base_margin ,
760+ missing = missing ,
761+ label_lower_bound = None ,
762+ label_upper_bound = None ,
763+ feature_names = feature_names ,
764+ feature_types = feature_types ,
765+ * args ,
766+ ** kwargs )
767+
768+ def get_data (
769+ self , rank : int , num_actors : Optional [int ] = None
770+ ) -> Dict [str , Union [None , pd .DataFrame , List [Optional [pd .DataFrame ]]]]:
771+ data_dict = super (RayDeviceQuantileDMatrix , self ).get_data (
772+ rank = rank , num_actors = num_actors )
773+ # Remove some dict keys here that are generated automatically
774+ data_dict .pop ("label_lower_bound" , None )
775+ data_dict .pop ("label_upper_bound" , None )
776+ return data_dict
739777
740778
741779def _can_load_distributed (source : Data ) -> bool :
0 commit comments