Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit dcdc4b7

Browse files
authored
Add feature weights support (#265)
* added support for feature_weights parameter and associated tests
* added tests for the feature_weights parameter
* ran code formatter
* revert version
* address lint_test failures
* addressed further test_lint failures
* addressing pytest failure reduce resource request
* addressing pytest failure increase boosting rounds for convergence
* addressing pytest failure increase boosting rounds for convergence
1 parent b8a7557 commit dcdc4b7

File tree

5 files changed

+84
-11
lines changed

5 files changed

+84
-11
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,7 @@ project-id
9090
# Downloaded test data
9191
*.csv
9292
*.csv.gz
93-
*.parquet
93+
*.parquet
94+
95+
# Byte-compiled files
96+
__pycache__/

xgboost_ray/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
328328
param["label"] = [param["label"]]
329329
if not isinstance(param["weight"], list):
330330
param["weight"] = [param["weight"]]
331+
if not isinstance(param["feature_weights"], list):
332+
param["feature_weights"] = [param["feature_weights"]]
331333
if not isinstance(param["qid"], list):
332334
param["qid"] = [param["qid"]]
333335
if not isinstance(param["data"], list):
@@ -354,6 +356,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
354356
"data": concat_dataframes(param["data"]),
355357
"label": concat_dataframes(param["label"]),
356358
"weight": concat_dataframes(param["weight"]),
359+
"feature_weights": concat_dataframes(param["feature_weights"]),
357360
"qid": concat_dataframes(param["qid"]),
358361
"base_margin": concat_dataframes(param["base_margin"]),
359362
"label_lower_bound": concat_dataframes(
@@ -365,6 +368,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
365368

366369
ll = param.pop("label_lower_bound", None)
367370
lu = param.pop("label_upper_bound", None)
371+
fw = param.pop("feature_weights", None)
368372

369373
if LEGACY_MATRIX:
370374
param.pop("base_margin", None)
@@ -378,7 +382,8 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
378382
matrix = xgb.DMatrix(**param)
379383

380384
if not LEGACY_MATRIX:
381-
matrix.set_info(label_lower_bound=ll, label_upper_bound=lu)
385+
matrix.set_info(
386+
label_lower_bound=ll, label_upper_bound=lu, feature_weights=fw)
382387

383388
data.update_matrix_properties(matrix)
384389
return matrix

xgboost_ray/matrix.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
label: List[Optional[Data]],
103103
missing: Optional[float],
104104
weight: List[Optional[Data]],
105+
feature_weights: List[Optional[Data]],
105106
qid: List[Optional[Data]],
106107
base_margin: List[Optional[Data]],
107108
label_lower_bound: List[Optional[Data]],
@@ -118,6 +119,7 @@ def __init__(
118119
self._label = label
119120
self._missing = missing
120121
self._weight = weight
122+
self._feature_weights = feature_weights
121123
self._qid = qid
122124
self._base_margin = base_margin
123125
self._label_lower_bound = label_lower_bound
@@ -151,6 +153,7 @@ def next(self, input_data: Callable):
151153
data=self._prop(self._data),
152154
label=self._prop(self._label),
153155
weight=self._prop(self._weight),
156+
feature_weights=self._prop(self._feature_weights),
154157
qid=self._prop(self._qid),
155158
group=None,
156159
label_lower_bound=self._prop(self._label_lower_bound),
@@ -168,6 +171,7 @@ def __init__(self,
168171
label: Optional[Data] = None,
169172
missing: Optional[float] = None,
170173
weight: Optional[Data] = None,
174+
feature_weights: Optional[Data] = None,
171175
base_margin: Optional[Data] = None,
172176
label_lower_bound: Optional[Data] = None,
173177
label_upper_bound: Optional[Data] = None,
@@ -182,6 +186,7 @@ def __init__(self,
182186
self.label = label
183187
self.missing = missing
184188
self.weight = weight
189+
self.feature_weights = feature_weights
185190
self.base_margin = base_margin
186191
self.label_lower_bound = label_lower_bound
187192
self.label_upper_bound = label_upper_bound
@@ -248,8 +253,8 @@ def _split_dataframe(
248253
"""
249254
Split dataframe into
250255
251-
`features`, `labels`, `weight`, `base_margin`, `label_lower_bound`,
252-
`label_upper_bound`
256+
`features`, `labels`, `weight`, `feature_weights`, `base_margin`,
257+
`label_lower_bound`, `label_upper_bound`
253258
254259
"""
255260
# sort dataframe by qid if exists (required by DMatrix)
@@ -268,6 +273,11 @@ def _split_dataframe(
268273
if exclude:
269274
exclude_cols.add(exclude)
270275

276+
feature_weights, exclude = data_source.get_column(
277+
local_data, self.feature_weights)
278+
if exclude:
279+
exclude_cols.add(exclude)
280+
271281
qid, exclude = data_source.get_column(local_data, self.qid)
272282
if exclude:
273283
exclude_cols.add(exclude)
@@ -291,8 +301,8 @@ def _split_dataframe(
291301
if exclude_cols:
292302
x = x[[col for col in x.columns if col not in exclude_cols]]
293303

294-
return x, label, weight, base_margin, label_lower_bound, \
295-
label_upper_bound, qid
304+
return x, label, weight, feature_weights, base_margin, \
305+
label_lower_bound, label_upper_bound, qid
296306

297307
def load_data(self,
298308
num_actors: int,
@@ -380,7 +390,7 @@ def load_data(self,
380390
# yet. Instead, we'll be selecting the rows below.
381391
local_df = data_source.load_data(
382392
self.data, ignore=self.ignore, indices=None, **self.kwargs)
383-
x, y, w, b, ll, lu, qid = self._split_dataframe(
393+
x, y, w, fw, b, ll, lu, qid = self._split_dataframe(
384394
local_df, data_source=data_source)
385395

386396
if isinstance(x, list):
@@ -396,6 +406,7 @@ def load_data(self,
396406
"data": ray.put(x.iloc[indices]),
397407
"label": ray.put(y.iloc[indices] if y is not None else None),
398408
"weight": ray.put(w.iloc[indices] if w is not None else None),
409+
"feature_weights": ray.put(fw),
399410
"base_margin": ray.put(b.iloc[indices]
400411
if b is not None else None),
401412
"label_lower_bound": ray.put(ll.iloc[indices]
@@ -545,7 +556,7 @@ def load_data(self,
545556
indices=rank_shards,
546557
ignore=self.ignore,
547558
**self.kwargs)
548-
x, y, w, b, ll, lu, qid = self._split_dataframe(
559+
x, y, w, fw, b, ll, lu, qid = self._split_dataframe(
549560
local_df, data_source=data_source)
550561

551562
if isinstance(x, list):
@@ -557,16 +568,16 @@ def load_data(self,
557568
indices = _get_sharding_indices(sharding, rank, num_actors, n)
558569

559570
if not indices:
560-
x, y, w, b, ll, lu, qid = (None, None, None, None, None, None,
561-
None)
571+
x, y, w, fw, b, ll, lu, qid = (None, None, None, None, None,
572+
None, None, None)
562573
n = 0
563574
else:
564575
local_df = data_source.load_data(
565576
self.data,
566577
ignore=self.ignore,
567578
indices=indices,
568579
**self.kwargs)
569-
x, y, w, b, ll, lu, qid = self._split_dataframe(
580+
x, y, w, fw, b, ll, lu, qid = self._split_dataframe(
570581
local_df, data_source=data_source)
571582

572583
if isinstance(x, list):
@@ -579,6 +590,7 @@ def load_data(self,
579590
"data": ray.put(x),
580591
"label": ray.put(y),
581592
"weight": ray.put(w),
593+
"feature_weights": ray.put(fw),
582594
"base_margin": ray.put(b),
583595
"label_lower_bound": ray.put(ll),
584596
"label_upper_bound": ray.put(lu),
@@ -684,6 +696,7 @@ def __init__(self,
684696
data: Data,
685697
label: Optional[Data] = None,
686698
weight: Optional[Data] = None,
699+
feature_weights: Optional[Data] = None,
687700
base_margin: Optional[Data] = None,
688701
missing: Optional[float] = None,
689702
label_lower_bound: Optional[Data] = None,
@@ -739,6 +752,7 @@ def __init__(self,
739752
label=label,
740753
missing=missing,
741754
weight=weight,
755+
feature_weights=feature_weights,
742756
base_margin=base_margin,
743757
label_lower_bound=label_lower_bound,
744758
label_upper_bound=label_upper_bound,
@@ -755,6 +769,7 @@ def __init__(self,
755769
label=label,
756770
missing=missing,
757771
weight=weight,
772+
feature_weights=feature_weights,
758773
base_margin=base_margin,
759774
label_lower_bound=label_lower_bound,
760775
label_upper_bound=label_upper_bound,

xgboost_ray/tests/test_end_to_end.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,47 @@ def testRanking(self):
375375
auc_rec = evals_result["train"]["aucpr"]
376376
self.assertTrue((p <= q for p, q in zip(auc_rec, auc_rec[1:])))
377377

378+
@unittest.skipIf(
    # Compare numeric (major, minor) components: a plain string comparison
    # is lexicographic and would treat e.g. "1.10.0" as older than "1.3.0".
    tuple(int(part) for part in xgb.__version__.split(".")[:2]) < (1, 3),
    f"not supported in xgb version {xgb.__version__}")
def testFeatureWeightsParam(self):
    """Test the feature_weights parameter for xgb version >= 1.3.0.

    Adapted from the official demo code:
    http://xgboost.readthedocs.io/en/stable/python/examples/
    feature_weights.html

    Trains on random data where feature ``i`` is given weight
    ``float(i)`` and ``colsample_bynode`` is 0.1, then inspects the
    scores returned by ``bst.get_fscore()``: feature ``f0`` (weight 0)
    must be absent and feature ``f9`` must hold the maximum score.
    """
    rng = np.random.RandomState(1994)

    kRows = 1000
    kCols = 10

    X = rng.randn(kRows, kCols)
    y = rng.randn(kRows)
    # Weights 0.0, 1.0, ..., kCols - 1; feature zero gets weight 0.
    fw = np.arange(kCols, dtype=float)
    train_set = RayDMatrix(X, y, feature_weights=fw)

    evals_result = {}
    bst = train(
        {
            "objective": "reg:squarederror",
            "eval_metric": ["rmse", "error"],
            "colsample_bynode": 0.1,
        },
        train_set,
        num_boost_round=250,
        evals_result=evals_result,
        evals=[(train_set, "train")],
        verbose_eval=False,
        ray_params=RayParams(
            num_actors=2,  # Number of remote actors
            cpus_per_actor=1))

    feature_map = bst.get_fscore()
    # feature zero has 0 weight, so it must not show up in the scores
    self.assertIsNone(feature_map.get("f0"))
    # the feature with the largest weight must have the largest score
    self.assertEqual(max(feature_map.values()), feature_map.get("f9"))
418+
378419

379420
if __name__ == "__main__":
380421
import pytest

xgboost_ray/tests/test_matrix.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,15 @@ def testLegacyParams(self):
368368
label_lower_bound=label_lower_bound,
369369
label_upper_bound=label_upper_bound)
370370

371+
@unittest.skipIf(
    # Compare numeric (major, minor) components: a plain string comparison
    # is lexicographic and would treat e.g. "1.10.0" as older than "1.3.0".
    tuple(int(part) for part in xgb.__version__.split(".")[:2]) < (1, 3),
    f"not supported in xgb version {xgb.__version__}")
def testFeatureWeightsParam(self):
    """Test the feature_weights parameter for xgb version >= 1.3.0.

    Passes a ``feature_weights`` array to ``_testMatrixCreation``
    together with the fixture data ``self.x`` / ``self.y``.
    """
    in_x = self.x
    in_y = self.y
    # NOTE(review): the weights are sized by the number of labels
    # (len(in_y)); feature weights are normally one value per feature
    # column -- confirm this length is what the check intends.
    feature_weights = np.arange(len(in_y))
    self._testMatrixCreation(in_x, in_y, feature_weights=feature_weights)
379+
371380
@unittest.skipIf("qid" not in inspect.signature(xgb.DMatrix).parameters,
372381
f"not supported in xgb version {xgb.__version__}")
373382
def testQidSortedBehaviorXGBoost(self):

0 commit comments

Comments
 (0)