Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 1e16381

Browse files
authored
Ensure that Dask Scheduler is set for Ray-on-Dask (#150)
* wrapper * bump wrapt * lint
1 parent b2b75f6 commit 1e16381

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"distributed computing framework Ray.",
1111
url="https://github.com/ray-project/xgboost_ray",
1212
install_requires=[
13-
"xgboost>=0.90", "ray", "numpy>=1.16,<1.20", "pandas", "pyarrow<5.0.0"
13+
"xgboost>=0.90", "ray", "numpy>=1.16,<1.20", "pandas", "pyarrow<5.0.0",
14+
"wrapt>=1.12.1"
1415
])
15-
# pyarrow<5.0.0 pinned until petastorm is updated
16+
# pyarrow<5.0.0 pinned until petastorm is updated

xgboost_ray/data_sources/dask.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
2-
from typing import Any, Optional, Sequence, Dict, Union, Tuple
2+
from typing import Any, List, Optional, Sequence, Dict, Union, Tuple
3+
import wrapt
34

45
import pandas as pd
56

@@ -29,6 +30,14 @@ def _assert_dask_installed():
2930
"the code should not have been reached.")
3031

3132

33+
@wrapt.decorator
34+
def ensure_ray_dask_initialized(func: Any, instance: Any, args: List[Any],
35+
kwargs: Any) -> Any:
36+
_assert_dask_installed()
37+
dask.config.set(scheduler=ray_dask_get)
38+
return func(*args, **kwargs)
39+
40+
3241
class Dask(DataSource):
3342
"""Read from distributed Dask dataframe.
3443
@@ -41,6 +50,7 @@ class Dask(DataSource):
4150
supports_central_loading = True
4251
supports_distributed_loading = True
4352

53+
@ensure_ray_dask_initialized
4454
@staticmethod
4555
def is_data_type(data: Any,
4656
filetype: Optional[RayFileType] = None) -> bool:
@@ -51,6 +61,7 @@ def is_data_type(data: Any,
5161

5262
return isinstance(data, (DaskDataFrame, DaskSeries))
5363

64+
@ensure_ray_dask_initialized
5465
@staticmethod
5566
def load_data(
5667
data: Any, # dask.pandas.DataFrame
@@ -79,6 +90,7 @@ def load_data(
7990

8091
return local_df
8192

93+
@ensure_ray_dask_initialized
8294
@staticmethod
8395
def convert_to_series(data: Any) -> pd.Series:
8496
_assert_dask_installed()
@@ -95,6 +107,7 @@ def convert_to_series(data: Any) -> pd.Series:
95107

96108
return DataSource.convert_to_series(data)
97109

110+
@ensure_ray_dask_initialized
98111
@staticmethod
99112
def get_actor_shards(
100113
data: Any, # dask.dataframe.DataFrame
@@ -109,6 +122,7 @@ def get_actor_shards(
109122

110123
return data, assign_partitions_to_actors(ip_to_parts, actor_rank_ips)
111124

125+
@ensure_ray_dask_initialized
112126
@staticmethod
113127
def get_n(data: Any):
114128
"""

0 commit comments

Comments
 (0)