Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 2ff8fcc

Browse files
authored
Always detect Ray Dataset as distributed (#253)
Ensures that we always use distributed loading by default with Ray Datasets. Followup to ray-project/ray#31079 Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
1 parent 5f016ff commit 2ff8fcc

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

xgboost_ray/matrix.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
try:
2626
from ray.data.dataset import Dataset as RayDataset
27-
except (ImportError, ModuleNotFoundError):
27+
except ImportError:
2828

2929
class RayDataset:
3030
pass
@@ -916,6 +916,8 @@ def _can_load_distributed(source: Data) -> bool:
916916
return False
917917
elif Modin.is_data_type(source):
918918
return True
919+
elif isinstance(source, RayDataset):
920+
return True
919921
elif isinstance(source, str):
920922
# Strings should point to files or URLs
921923
# Usually parquet files point to directories
@@ -940,6 +942,8 @@ def _detect_distributed(source: Data) -> bool:
940942
return False
941943
if Modin.is_data_type(source):
942944
return True
945+
if isinstance(source, RayDataset):
946+
return True
943947
if isinstance(source, Iterable) and not isinstance(source, str) and \
944948
not (isinstance(source, Sequence) and isinstance(source[0], str)):
945949
# This is an iterable but not a Sequence of strings, and not a

xgboost_ray/tests/test_matrix.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
import pandas as pd
99

1010
import ray
11+
try:
12+
import ray.data as ray_data
13+
except (ImportError, ModuleNotFoundError):
14+
15+
ray_data = None
1116

1217
from xgboost_ray import RayDMatrix
1318
from xgboost_ray.matrix import (concat_dataframes, RayShardingMode,
@@ -29,7 +34,7 @@ def setUp(self):
2934

3035
@classmethod
3136
def setUpClass(cls):
32-
ray.init(local_mode=True)
37+
ray.init()
3338

3439
@classmethod
3540
def tearDownClass(cls):
@@ -315,6 +320,11 @@ def testDetectDistributed(self):
315320
mat = RayDMatrix([csv_file] * 3, lazy=True)
316321
self.assertTrue(mat.distributed)
317322

323+
if ray_data:
324+
ds = ray_data.read_parquet(parquet_file)
325+
mat = RayDMatrix(ds)
326+
self.assertTrue(mat.distributed)
327+
318328
def testTooManyActorsDistributed(self):
319329
"""Test error when too many actors are passed"""
320330
with self.assertRaises(RuntimeError):

0 commit comments

Comments
 (0)