Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit e904925

Browse files
authored
Add multi label support v2 (#306)
* Add multi label support to xgboost ray
* fix lint
* add a missing change
* add another missing change
* fix lint
1 parent 5a840af commit e904925

File tree

3 files changed

+73
-9
lines changed

3 files changed

+73
-9
lines changed

xgboost_ray/data_sources/data_source.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
2+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
33

44
import pandas as pd
55
from ray.actor import ActorHandle
@@ -118,12 +118,12 @@ def convert_to_series(data: Any) -> pd.Series:
118118
@classmethod
119119
def get_column(
120120
cls, data: pd.DataFrame, column: Any
121-
) -> Tuple[pd.Series, Optional[str]]:
121+
) -> Tuple[pd.Series, Optional[Union[str, List]]]:
122122
"""Helper method wrapping around convert to series.
123123
124124
This method should usually not be overwritten.
125125
"""
126-
if isinstance(column, str):
126+
if isinstance(column, str) or isinstance(column, List):
127127
return data[column], column
128128
elif column is not None:
129129
return cls.convert_to_series(column), None

xgboost_ray/matrix.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,10 @@ def _split_dataframe(
307307

308308
label, exclude = data_source.get_column(local_data, self.label)
309309
if exclude:
310-
exclude_cols.add(exclude)
310+
if isinstance(exclude, List):
311+
exclude_cols.update(exclude)
312+
else:
313+
exclude_cols.add(exclude)
311314

312315
weight, exclude = data_source.get_column(local_data, self.weight)
313316
if exclude:
@@ -406,7 +409,11 @@ def get_data_source(self) -> Type[DataSource]:
406409
): # noqa: E721:
407410
# Label is an object of a different type than the main data.
408411
# We have to make sure they are compatible
409-
if not data_source.is_data_type(self.label):
412+
# if it's a parquet data source and label is a list,
413+
# then we consider it a multi-label data
414+
if not data_source.is_data_type(self.label) and not (
415+
isinstance(self.label, List) and data_source.__name__ == "Parquet"
416+
):
410417
raise ValueError(
411418
"The passed `data` and `label` types are not compatible."
412419
"\nFIX THIS by passing the same types to the "
@@ -521,7 +528,11 @@ def get_data_source(self) -> Type[DataSource]:
521528
f"RayDMatrix."
522529
)
523530

524-
if self.label is not None and not isinstance(self.label, str):
531+
if (
532+
self.label is not None
533+
and not isinstance(self.label, str)
534+
and not isinstance(self.label, List)
535+
):
525536
raise ValueError(
526537
f"Invalid `label` value for distributed datasets: "
527538
f"{self.label}. Only strings are supported. "

xgboost_ray/tests/test_matrix.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ def setUp(self):
3333
* repeat
3434
)
3535
self.y = np.array([0, 1, 2, 3] * repeat)
36+
self.multi_y = np.array(
37+
[
38+
[1, 0, 0, 0],
39+
[0, 1, 0, 0],
40+
[0, 0, 1, 1],
41+
[0, 0, 1, 0],
42+
]
43+
* repeat
44+
)
3645

3746
@classmethod
3847
def setUpClass(cls):
@@ -62,7 +71,7 @@ def testColumnOrdering(self):
6271

6372
assert data.columns.tolist() == cols[:-1]
6473

65-
def _testMatrixCreation(self, in_x, in_y, **kwargs):
74+
def _testMatrixCreation(self, in_x, in_y, multi_label=False, **kwargs):
6675
if "sharding" not in kwargs:
6776
kwargs["sharding"] = RayShardingMode.BATCH
6877
mat = RayDMatrix(in_x, in_y, **kwargs)
@@ -81,7 +90,10 @@ def _load_data(params):
8190
x, y = _load_data(params)
8291

8392
self.assertTrue(np.allclose(self.x, x))
84-
self.assertTrue(np.allclose(self.y, y))
93+
if multi_label:
94+
self.assertTrue(np.allclose(self.multi_y, y))
95+
else:
96+
self.assertTrue(np.allclose(self.y, y))
8597

8698
# Multi actor check
8799
mat = RayDMatrix(in_x, in_y, **kwargs)
@@ -95,7 +107,10 @@ def _load_data(params):
95107
x2, y2 = _load_data(params)
96108

97109
self.assertTrue(np.allclose(self.x, concat_dataframes([x1, x2])))
98-
self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2])))
110+
if multi_label:
111+
self.assertTrue(np.allclose(self.multi_y, concat_dataframes([y1, y2])))
112+
else:
113+
self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2])))
99114

100115
def testFromNumpy(self):
101116
in_x = self.x
@@ -276,6 +291,22 @@ def testFromMultiCSVString(self):
276291
[data_file_1, data_file_2], "label", distributed=True
277292
)
278293

294+
def testFromParquetStringMultiLabel(self):
295+
with tempfile.TemporaryDirectory() as dir:
296+
data_file = os.path.join(dir, "data.parquet")
297+
298+
data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"])
299+
labels = [f"label_{label}" for label in range(4)]
300+
data_df[labels] = self.multi_y
301+
data_df.to_parquet(data_file)
302+
303+
self._testMatrixCreation(
304+
data_file, labels, multi_label=True, distributed=False
305+
)
306+
self._testMatrixCreation(
307+
data_file, labels, multi_label=True, distributed=True
308+
)
309+
279310
def testFromParquetString(self):
280311
with tempfile.TemporaryDirectory() as dir:
281312
data_file = os.path.join(dir, "data.parquet")
@@ -287,6 +318,28 @@ def testFromParquetString(self):
287318
self._testMatrixCreation(data_file, "label", distributed=False)
288319
self._testMatrixCreation(data_file, "label", distributed=True)
289320

321+
def testFromMultiParquetStringMultiLabel(self):
322+
with tempfile.TemporaryDirectory() as dir:
323+
data_file_1 = os.path.join(dir, "data_1.parquet")
324+
data_file_2 = os.path.join(dir, "data_2.parquet")
325+
326+
data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"])
327+
labels = [f"label_{label}" for label in range(4)]
328+
data_df[labels] = self.multi_y
329+
330+
df_1 = data_df[0 : len(data_df) // 2]
331+
df_2 = data_df[len(data_df) // 2 :]
332+
333+
df_1.to_parquet(data_file_1)
334+
df_2.to_parquet(data_file_2)
335+
336+
self._testMatrixCreation(
337+
[data_file_1, data_file_2], labels, multi_label=True, distributed=False
338+
)
339+
self._testMatrixCreation(
340+
[data_file_1, data_file_2], labels, multi_label=True, distributed=True
341+
)
342+
290343
def testFromMultiParquetString(self):
291344
with tempfile.TemporaryDirectory() as dir:
292345
data_file_1 = os.path.join(dir, "data_1.parquet")

0 commit comments

Comments
 (0)