@@ -102,6 +102,7 @@ def __init__(
102102 label : List [Optional [Data ]],
103103 missing : Optional [float ],
104104 weight : List [Optional [Data ]],
105+ feature_weights : List [Optional [Data ]],
105106 qid : List [Optional [Data ]],
106107 base_margin : List [Optional [Data ]],
107108 label_lower_bound : List [Optional [Data ]],
@@ -118,6 +119,7 @@ def __init__(
118119 self ._label = label
119120 self ._missing = missing
120121 self ._weight = weight
122+ self ._feature_weights = feature_weights
121123 self ._qid = qid
122124 self ._base_margin = base_margin
123125 self ._label_lower_bound = label_lower_bound
@@ -151,6 +153,7 @@ def next(self, input_data: Callable):
151153 data = self ._prop (self ._data ),
152154 label = self ._prop (self ._label ),
153155 weight = self ._prop (self ._weight ),
156+ feature_weights = self ._prop (self ._feature_weights ),
154157 qid = self ._prop (self ._qid ),
155158 group = None ,
156159 label_lower_bound = self ._prop (self ._label_lower_bound ),
@@ -168,6 +171,7 @@ def __init__(self,
168171 label : Optional [Data ] = None ,
169172 missing : Optional [float ] = None ,
170173 weight : Optional [Data ] = None ,
174+ feature_weights : Optional [Data ] = None ,
171175 base_margin : Optional [Data ] = None ,
172176 label_lower_bound : Optional [Data ] = None ,
173177 label_upper_bound : Optional [Data ] = None ,
@@ -182,6 +186,7 @@ def __init__(self,
182186 self .label = label
183187 self .missing = missing
184188 self .weight = weight
189+ self .feature_weights = feature_weights
185190 self .base_margin = base_margin
186191 self .label_lower_bound = label_lower_bound
187192 self .label_upper_bound = label_upper_bound
@@ -248,8 +253,8 @@ def _split_dataframe(
248253 """
249254 Split dataframe into
250255
251- `features`, `labels`, `weight`, `base_margin `, `label_lower_bound `,
252- `label_upper_bound`
256+ `features`, `labels`, `weight`, `feature_weights `, `base_margin `,
257+ `label_lower_bound`, ` label_upper_bound`
253258
254259 """
255260 # sort dataframe by qid if exists (required by DMatrix)
@@ -268,6 +273,11 @@ def _split_dataframe(
268273 if exclude :
269274 exclude_cols .add (exclude )
270275
276+ feature_weights , exclude = data_source .get_column (
277+ local_data , self .feature_weights )
278+ if exclude :
279+ exclude_cols .add (exclude )
280+
271281 qid , exclude = data_source .get_column (local_data , self .qid )
272282 if exclude :
273283 exclude_cols .add (exclude )
@@ -291,8 +301,8 @@ def _split_dataframe(
291301 if exclude_cols :
292302 x = x [[col for col in x .columns if col not in exclude_cols ]]
293303
294- return x , label , weight , base_margin , label_lower_bound , \
295- label_upper_bound , qid
304+ return x , label , weight , feature_weights , base_margin , \
305+ label_lower_bound , label_upper_bound , qid
296306
297307 def load_data (self ,
298308 num_actors : int ,
@@ -380,7 +390,7 @@ def load_data(self,
380390 # yet. Instead, we'll be selecting the rows below.
381391 local_df = data_source .load_data (
382392 self .data , ignore = self .ignore , indices = None , ** self .kwargs )
383- x , y , w , b , ll , lu , qid = self ._split_dataframe (
393+ x , y , w , fw , b , ll , lu , qid = self ._split_dataframe (
384394 local_df , data_source = data_source )
385395
386396 if isinstance (x , list ):
@@ -396,6 +406,7 @@ def load_data(self,
396406 "data" : ray .put (x .iloc [indices ]),
397407 "label" : ray .put (y .iloc [indices ] if y is not None else None ),
398408 "weight" : ray .put (w .iloc [indices ] if w is not None else None ),
409+ "feature_weights" : ray .put (fw ),
399410 "base_margin" : ray .put (b .iloc [indices ]
400411 if b is not None else None ),
401412 "label_lower_bound" : ray .put (ll .iloc [indices ]
@@ -545,7 +556,7 @@ def load_data(self,
545556 indices = rank_shards ,
546557 ignore = self .ignore ,
547558 ** self .kwargs )
548- x , y , w , b , ll , lu , qid = self ._split_dataframe (
559+ x , y , w , fw , b , ll , lu , qid = self ._split_dataframe (
549560 local_df , data_source = data_source )
550561
551562 if isinstance (x , list ):
@@ -557,16 +568,16 @@ def load_data(self,
557568 indices = _get_sharding_indices (sharding , rank , num_actors , n )
558569
559570 if not indices :
560- x , y , w , b , ll , lu , qid = (None , None , None , None , None , None ,
561- None )
571+ x , y , w , fw , b , ll , lu , qid = (None , None , None , None , None ,
572+ None , None , None )
562573 n = 0
563574 else :
564575 local_df = data_source .load_data (
565576 self .data ,
566577 ignore = self .ignore ,
567578 indices = indices ,
568579 ** self .kwargs )
569- x , y , w , b , ll , lu , qid = self ._split_dataframe (
580+ x , y , w , fw , b , ll , lu , qid = self ._split_dataframe (
570581 local_df , data_source = data_source )
571582
572583 if isinstance (x , list ):
@@ -579,6 +590,7 @@ def load_data(self,
579590 "data" : ray .put (x ),
580591 "label" : ray .put (y ),
581592 "weight" : ray .put (w ),
593+ "feature_weights" : ray .put (fw ),
582594 "base_margin" : ray .put (b ),
583595 "label_lower_bound" : ray .put (ll ),
584596 "label_upper_bound" : ray .put (lu ),
@@ -684,6 +696,7 @@ def __init__(self,
684696 data : Data ,
685697 label : Optional [Data ] = None ,
686698 weight : Optional [Data ] = None ,
699+ feature_weights : Optional [Data ] = None ,
687700 base_margin : Optional [Data ] = None ,
688701 missing : Optional [float ] = None ,
689702 label_lower_bound : Optional [Data ] = None ,
@@ -739,6 +752,7 @@ def __init__(self,
739752 label = label ,
740753 missing = missing ,
741754 weight = weight ,
755+ feature_weights = feature_weights ,
742756 base_margin = base_margin ,
743757 label_lower_bound = label_lower_bound ,
744758 label_upper_bound = label_upper_bound ,
@@ -755,6 +769,7 @@ def __init__(self,
755769 label = label ,
756770 missing = missing ,
757771 weight = weight ,
772+ feature_weights = feature_weights ,
758773 base_margin = base_margin ,
759774 label_lower_bound = label_lower_bound ,
760775 label_upper_bound = label_upper_bound ,
0 commit comments