11from collections import defaultdict
2- from typing import Any , Optional , Sequence , Dict , Union , Tuple
2+ from typing import Any , List , Optional , Sequence , Dict , Union , Tuple
3+ import wrapt
34
45import pandas as pd
56
@@ -29,6 +30,14 @@ def _assert_dask_installed():
2930 "the code should not have been reached." )
3031
3132
33+ @wrapt .decorator
34+ def ensure_ray_dask_initialized (func : Any , instance : Any , args : List [Any ],
35+ kwargs : Any ) -> Any :
36+ _assert_dask_installed ()
37+ dask .config .set (scheduler = ray_dask_get )
38+ return func (* args , ** kwargs )
39+
40+
3241class Dask (DataSource ):
3342 """Read from distributed Dask dataframe.
3443
@@ -41,6 +50,7 @@ class Dask(DataSource):
4150 supports_central_loading = True
4251 supports_distributed_loading = True
4352
53+ @ensure_ray_dask_initialized
4454 @staticmethod
4555 def is_data_type (data : Any ,
4656 filetype : Optional [RayFileType ] = None ) -> bool :
@@ -51,6 +61,7 @@ def is_data_type(data: Any,
5161
5262 return isinstance (data , (DaskDataFrame , DaskSeries ))
5363
64+ @ensure_ray_dask_initialized
5465 @staticmethod
5566 def load_data (
5667 data : Any , # dask.pandas.DataFrame
@@ -79,6 +90,7 @@ def load_data(
7990
8091 return local_df
8192
93+ @ensure_ray_dask_initialized
8294 @staticmethod
8395 def convert_to_series (data : Any ) -> pd .Series :
8496 _assert_dask_installed ()
@@ -95,6 +107,7 @@ def convert_to_series(data: Any) -> pd.Series:
95107
96108 return DataSource .convert_to_series (data )
97109
110+ @ensure_ray_dask_initialized
98111 @staticmethod
99112 def get_actor_shards (
100113 data : Any , # dask.dataframe.DataFrame
@@ -109,6 +122,7 @@ def get_actor_shards(
109122
110123 return data , assign_partitions_to_actors (ip_to_parts , actor_rank_ips )
111124
125+ @ensure_ray_dask_initialized
112126 @staticmethod
113127 def get_n (data : Any ):
114128 """
0 commit comments