
Commit 733c24c

Re-factor data loading structure (#66)
* Re-factor data loading structure
* Better tests, better documentation
* Update xgboost_ray/data_sources/data_source.py
* Update xgboost_ray/data_sources/data_source.py
* Update xgboost_ray/data_sources/data_source.py
* Update docs
* Resolve breaking api change
* Resolve breaking api change (cont)
* Resolve breaking api change (cont)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
1 parent e7867d9 commit 733c24c

File tree

12 files changed: +614 −245 lines
xgboost_ray/data_sources/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
from xgboost_ray.data_sources.data_source import DataSource, RayFileType
from xgboost_ray.data_sources.numpy import Numpy
from xgboost_ray.data_sources.pandas import Pandas
from xgboost_ray.data_sources.modin import Modin
from xgboost_ray.data_sources.ml_dataset import MLDataset
from xgboost_ray.data_sources.petastorm import Petastorm
from xgboost_ray.data_sources.csv import CSV
from xgboost_ray.data_sources.parquet import Parquet

data_sources = [Numpy, Pandas, Modin, MLDataset, Petastorm, CSV, Parquet]

__all__ = [
    "DataSource", "RayFileType", "Numpy", "Pandas", "Modin", "MLDataset",
    "Petastorm", "CSV", "Parquet"
]
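The `data_sources` list acts as a registry of available loaders. A minimal sketch of how a caller can resolve the right source by probing each class in order; the `find_data_source` helper is illustrative and not a function defined in this commit:

# Illustrative helper (not part of this commit): pick the first registered
# source that recognizes the input data.
from typing import Any, Optional

from xgboost_ray.data_sources import data_sources, RayFileType


def find_data_source(data: Any,
                     filetype: Optional[RayFileType] = None):
    for source in data_sources:
        if source.is_data_type(data, filetype):
            return source
    raise ValueError(f"Unsupported data type: {type(data)}")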

xgboost_ray/data_sources/csv.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
from typing import Any, Optional, Sequence, Iterable, Union

import pandas as pd

from xgboost_ray.data_sources.data_source import DataSource, RayFileType
from xgboost_ray.data_sources.pandas import Pandas


class CSV(DataSource):
    """Read one or many CSV files."""
    supports_central_loading = True
    supports_distributed_loading = True

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        return filetype == RayFileType.CSV

    @staticmethod
    def get_filetype(data: Any) -> Optional[RayFileType]:
        if data.endswith(".csv") or data.endswith("csv.gz"):
            return RayFileType.CSV
        return None

    @staticmethod
    def load_data(data: Union[str, Sequence[str]],
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs):
        if isinstance(data, Iterable) and not isinstance(data, str):
            shards = []

            for i, shard in enumerate(data):
                if indices and i not in indices:
                    continue
                shard_df = pd.read_csv(shard, **kwargs)
                shards.append(Pandas.load_data(shard_df, ignore=ignore))
            return pd.concat(shards, copy=False)
        else:
            local_df = pd.read_csv(data, **kwargs)
            return Pandas.load_data(local_df, ignore=ignore)
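A short usage sketch (the file names are placeholders, assumed to exist): passing a list of paths together with an `indices` filter loads only the selected shards, which is what enables distributed loading across workers.

# Hypothetical example: load only shards 0 and 2 and drop a label column.
from xgboost_ray.data_sources.csv import CSV

files = ["part0.csv", "part1.csv", "part2.csv"]  # placeholder paths
df = CSV.load_data(files, ignore=["label"], indices=[0, 2])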
xgboost_ray/data_sources/data_source.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
from enum import Enum
from typing import Any, Optional, Sequence, Tuple

import pandas as pd


class RayFileType(Enum):
    """Enum for different file types (used for overrides)."""
    CSV = 1
    PARQUET = 2
    PETASTORM = 3


class DataSource:
    """Abstract class for data sources.

    xgboost_ray supports reading from various sources, such as files
    (e.g. CSV, Parquet) or distributed datasets (Ray MLDataset, Modin).

    This abstract class defines an interface to read from these sources.
    New data sources can be added by implementing this interface.

    ``DataSource`` classes are not instantiated. Instead, static and
    class methods are called directly.
    """
    supports_central_loading = True
    supports_distributed_loading = False

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        """Check if the supplied data matches this data source.

        Args:
            data (Any): Dataset.
            filetype (Optional[RayFileType]): RayFileType of the provided
                dataset. Some DataSource implementations might require
                that this is explicitly set (e.g. if multiple sources can
                read CSV files).

        Returns:
            Boolean indicating if this data source belongs to/is compatible
            with the data.
        """
        return False

    @staticmethod
    def get_filetype(data: Any) -> Optional[RayFileType]:
        """Method to help infer the filetype.

        Returns None if the supplied data type (usually a filename)
        is not covered by this data source, otherwise the filetype
        is returned.

        Args:
            data (Any): Data set.

        Returns:
            RayFileType or None.
        """
        return None

    @staticmethod
    def load_data(data: Any,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs) -> pd.DataFrame:
        """Load data into a pandas dataframe.

        Ignore specific columns, and optionally select specific indices.

        Args:
            data (Any): Input data.
            ignore (Optional[Sequence[str]]): Column names to ignore.
            indices (Optional[Sequence[int]]): Indices to select. What an
                index indicates depends on the data source.

        Returns:
            Pandas DataFrame.
        """
        raise NotImplementedError

    @staticmethod
    def convert_to_series(data: Any) -> pd.Series:
        """Convert data from the data source type to a pandas series."""
        if isinstance(data, pd.DataFrame):
            return pd.Series(data.squeeze())

        if not isinstance(data, pd.Series):
            return pd.Series(data)

        return data

    @classmethod
    def get_column(cls, data: pd.DataFrame,
                   column: Any) -> Tuple[pd.Series, Optional[str]]:
        """Helper method wrapping around convert to series.

        This method should usually not be overwritten.
        """
        if isinstance(column, str):
            return data[column], column
        elif column is not None:
            return cls.convert_to_series(column), None
        return column, None

    @staticmethod
    def get_n(data: Any):
        """Get length of data source partitions for sharding."""
        return len(list(data))
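Since new sources are added by implementing this interface, here is a minimal hedged sketch of a custom subclass. A JSON Lines source is not part of this commit; it only illustrates the methods most implementations override:

# Hypothetical subclass for illustration only.
from typing import Any, Optional, Sequence

import pandas as pd

from xgboost_ray.data_sources.data_source import DataSource, RayFileType


class JSONLines(DataSource):
    """Illustrative example: read newline-delimited JSON files."""

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        return isinstance(data, str) and data.endswith(".jsonl")

    @staticmethod
    def load_data(data: Any,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs) -> pd.DataFrame:
        local_df = pd.read_json(data, lines=True, **kwargs)
        if ignore:
            local_df = local_df[local_df.columns.difference(ignore)]
        if indices:
            local_df = local_df.iloc[indices]
        return local_df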
xgboost_ray/data_sources/ml_dataset.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
from typing import Any, Optional, Sequence, List

import pandas as pd
from ray.util.data import MLDataset as MLDatasetType

from xgboost_ray.data_sources.data_source import DataSource, RayFileType


class MLDataset(DataSource):
    """Read from distributed Ray MLDataset.

    The Ray MLDataset is a distributed dataset based on Ray's
    `parallel iterators <https://docs.ray.io/en/master/iter.html>`_.

    Shards of the MLDataset can be stored on different nodes, making
    it suitable for distributed loading.
    """
    supports_central_loading = True
    supports_distributed_loading = True

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        return isinstance(data, MLDatasetType)

    @staticmethod
    def load_data(data: MLDatasetType,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs):
        indices = indices or list(range(0, data.num_shards()))

        shards: List[pd.DataFrame] = [
            pd.concat(data.get_shard(i), copy=False) for i in indices
        ]

        # Concat all shards
        local_df = pd.concat(shards, copy=False)

        if ignore:
            local_df = local_df[local_df.columns.difference(ignore)]

        return local_df

    @staticmethod
    def get_n(data: MLDatasetType):
        return data.num_shards()
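Because shards can live on different nodes, a distributed loader can hand each training actor a disjoint set of shard indices and call `load_data` with them. A sketch under the assumption of a simple round-robin split (not necessarily the exact assignment scheme xgboost_ray uses):

# Illustration: assign shard indices round-robin across actors; each actor
# then calls MLDataset.load_data(data, indices=...) with its own subset.
def shard_indices_for_actor(num_shards: int, actor_rank: int,
                            num_actors: int):
    return [i for i in range(num_shards) if i % num_actors == actor_rank]

# With 8 shards and 2 actors:
# actor 0 -> [0, 2, 4, 6], actor 1 -> [1, 3, 5, 7]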

xgboost_ray/data_sources/modin.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
from typing import Any, Optional, Sequence

from xgboost_ray.data_sources.data_source import DataSource, RayFileType

import pandas as pd

try:
    import modin  # noqa: F401
    MODIN_INSTALLED = True
except ImportError:
    MODIN_INSTALLED = False


def _assert_modin_installed():
    if not MODIN_INSTALLED:
        raise RuntimeError(
            "Tried to use Modin as a data source, but modin is not "
            "installed. This function shouldn't have been called. "
            "\nFIX THIS by installing modin: `pip install modin`. "
            "\nPlease also raise an issue on our GitHub: "
            "https://github.com/ray-project/xgboost_ray as this part of "
            "the code should not have been reached.")


class Modin(DataSource):
    """Read from distributed Modin dataframe.

    `Modin <https://github.com/modin-project/modin>`_ is a distributed
    drop-in replacement for pandas supporting Ray as a backend.

    Modin dataframes are stored on multiple actors, making them
    suitable for distributed loading.
    """

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        if not MODIN_INSTALLED:
            return False
        from modin.pandas import DataFrame as ModinDataFrame, \
            Series as ModinSeries

        return isinstance(data, (ModinDataFrame, ModinSeries))

    @staticmethod
    def load_data(
            data: Any,  # modin.pandas.DataFrame
            ignore: Optional[Sequence[str]] = None,
            indices: Optional[Sequence[int]] = None,
            **kwargs) -> pd.DataFrame:
        _assert_modin_installed()
        local_df = data
        if indices:
            # Select rows by position before converting to pandas
            local_df = local_df.iloc[indices]

        local_df = local_df._to_pandas()

        if ignore:
            local_df = local_df[local_df.columns.difference(ignore)]

        return local_df

    @staticmethod
    def convert_to_series(data: Any) -> pd.Series:
        _assert_modin_installed()
        from modin.pandas import DataFrame as ModinDataFrame, \
            Series as ModinSeries

        if isinstance(data, ModinDataFrame):
            return pd.Series(data._to_pandas().squeeze())
        elif isinstance(data, ModinSeries):
            return data._to_pandas()

        return DataSource.convert_to_series(data)
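A hedged usage sketch, assuming modin is installed (`pip install modin`): Modin frames are detected by `is_data_type` and converted to a local pandas dataframe by `load_data`.

import modin.pandas as mpd

from xgboost_ray.data_sources.modin import Modin

mdf = mpd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
assert Modin.is_data_type(mdf)
local = Modin.load_data(mdf, ignore=["b"])  # a plain pandas.DataFrame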

xgboost_ray/data_sources/numpy.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
from typing import Any, Optional, Sequence

import numpy as np
import pandas as pd

from xgboost_ray.data_sources.data_source import DataSource, RayFileType
from xgboost_ray.data_sources.pandas import Pandas


class Numpy(DataSource):
    """Read from numpy arrays."""

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        return isinstance(data, np.ndarray)

    @staticmethod
    def load_data(data: np.ndarray,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs) -> pd.DataFrame:
        local_df = pd.DataFrame(
            data, columns=[f"f{i}" for i in range(data.shape[1])])
        return Pandas.load_data(local_df, ignore=ignore, indices=indices)
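For illustration: a two-dimensional array is wrapped in a dataframe with generated column names `f0`, `f1`, ... before being handed to the `Pandas` source.

import numpy as np

from xgboost_ray.data_sources.numpy import Numpy

arr = np.arange(6).reshape(3, 2)
df = Numpy.load_data(arr)
# df.columns -> Index(['f0', 'f1'], dtype='object')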

xgboost_ray/data_sources/pandas.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
from typing import Any, Optional, Sequence

import pandas as pd

from xgboost_ray.data_sources.data_source import DataSource, RayFileType


class Pandas(DataSource):
    """Read from pandas dataframes and series."""

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        return isinstance(data, (pd.DataFrame, pd.Series))

    @staticmethod
    def load_data(data: Any,
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs) -> pd.DataFrame:
        local_df = data

        if ignore:
            local_df = local_df[local_df.columns.difference(ignore)]

        if indices:
            return local_df.iloc[indices]

        return local_df
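A small example of the `ignore`/`indices` semantics: columns are dropped first, then rows are selected by position.

import pandas as pd

from xgboost_ray.data_sources.pandas import Pandas

df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
out = Pandas.load_data(df, ignore=["b"], indices=[0, 2])
# out contains column 'a' with rows 0 and 2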
xgboost_ray/data_sources/parquet.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
from typing import Any, Optional, Sequence, Iterable, Union

import pandas as pd

from xgboost_ray.data_sources.data_source import DataSource, RayFileType
from xgboost_ray.data_sources.pandas import Pandas


class Parquet(DataSource):
    """Read one or many Parquet files."""
    supports_central_loading = True
    supports_distributed_loading = True

    @staticmethod
    def is_data_type(data: Any,
                     filetype: Optional[RayFileType] = None) -> bool:
        return filetype == RayFileType.PARQUET

    @staticmethod
    def get_filetype(data: Any) -> Optional[RayFileType]:
        if data.endswith(".parquet"):
            return RayFileType.PARQUET
        return None

    @staticmethod
    def load_data(data: Union[str, Sequence[str]],
                  ignore: Optional[Sequence[str]] = None,
                  indices: Optional[Sequence[int]] = None,
                  **kwargs) -> pd.DataFrame:
        if isinstance(data, Iterable) and not isinstance(data, str):
            shards = []

            for i, shard in enumerate(data):
                if indices and i not in indices:
                    continue

                shard_df = pd.read_parquet(shard, **kwargs)
                shards.append(Pandas.load_data(shard_df, ignore=ignore))
            return pd.concat(shards, copy=False)
        else:
            local_df = pd.read_parquet(data, **kwargs)
            return Pandas.load_data(local_df, ignore=ignore)
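As with CSV, filetype inference keys off the file extension, and sharded loading selects files by index. A short sketch (paths are placeholders):

from xgboost_ray.data_sources.data_source import RayFileType
from xgboost_ray.data_sources.parquet import Parquet

assert Parquet.get_filetype("train.parquet") == RayFileType.PARQUET
assert Parquet.get_filetype("train.csv") is None

files = [f"part_{i}.parquet" for i in range(4)]  # placeholder paths
df = Parquet.load_data(files, indices=[1, 3])  # loads only two shards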
