This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 8668a77
Repartition Ray dataset if number of shards is too small (#283)
Currently we throw an error when the number of partitions in a data source is too small for the number of workers. However, in the case of Ray datasets, we can actually repartition the dataset ourselves. This will also ensure our quickstart examples, such as in https://docs.ray.io/en/latest/train/train.html#quick-start-to-distributed-training-with-ray-train will work out of the box.
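The behavior described above can be sketched as follows. This is a minimal illustration, not the library's actual code: `FakeDataset` is a hypothetical stand-in for `ray.data.Dataset`, whose real `repartition(num_blocks)` method it mimics, and `ensure_enough_shards` is an illustrative helper name.

```python
class FakeDataset:
    """Hypothetical stand-in for ray.data.Dataset."""

    def __init__(self, num_blocks):
        self._num_blocks = num_blocks

    def num_blocks(self):
        return self._num_blocks

    def repartition(self, num_blocks):
        # Ray would redistribute the rows across `num_blocks` blocks here.
        return FakeDataset(num_blocks)


def ensure_enough_shards(dataset, num_actors):
    """Repartition the dataset if it has fewer blocks than workers."""
    if dataset.num_blocks() < num_actors:
        dataset = dataset.repartition(num_actors)
    return dataset


ds = ensure_enough_shards(FakeDataset(num_blocks=1), num_actors=4)
print(ds.num_blocks())  # 4
```

A dataset that already has enough blocks is returned unchanged, so the repartition cost is only paid when the data source is genuinely under-partitioned.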
1 parent b45c5d9 commit 8668a77

File tree

3 files changed: +5 additions, −3 deletions

xgboost_ray/data_sources/data_source.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ class methods are called directly.
 
     supports_central_loading = True
     supports_distributed_loading = False
+    needs_partitions = True
 
     @staticmethod
     def is_data_type(data: Any, filetype: Optional[RayFileType] = None) -> bool:

xgboost_ray/data_sources/ray_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -34,6 +34,7 @@ class RayDataset(DataSource):
 
     supports_central_loading = True
     supports_distributed_loading = True
+    needs_partitions = False
 
     @staticmethod
     def is_data_type(data: Any, filetype: Optional[RayFileType] = None) -> bool:

@@ -102,7 +103,7 @@ def get_actor_shards(
         }
 
     @staticmethod
-    def get_n(data: Any):
+    def get_n(data: "ray.data.dataset.Dataset"):
         """
         Return number of distributed blocks.
         """

xgboost_ray/matrix.py

Lines changed: 2 additions & 2 deletions
@@ -430,7 +430,7 @@ def load_data(
         data_source = self.get_data_source()
 
         max_num_shards = self._cached_n or data_source.get_n(self.data)
-        if num_actors > max_num_shards:
+        if num_actors > max_num_shards and data_source.needs_partitions:
             raise RuntimeError(
                 f"Trying to shard data for {num_actors} actors, but the "
                 f"maximum number of shards (i.e. the number of data rows) "

@@ -565,7 +565,7 @@ def assert_enough_shards_for_actors(self, num_actors: int):
             return
 
         max_num_shards = self._cached_n or data_source.get_n(self.data)
-        if num_actors > max_num_shards:
+        if num_actors > max_num_shards and data_source.needs_partitions:
             raise RuntimeError(
                 f"Trying to shard data for {num_actors} actors, but the "
                 f"maximum number of shards is {max_num_shards}. If you "
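The effect of the new `needs_partitions` guard can be sketched with self-contained stand-ins. All class and function names below are hypothetical illustrations, not the library's actual classes: only sources that cannot repartition themselves (`needs_partitions = True`) still raise when there are too few shards.

```python
class PandasLikeSource:
    # Hypothetical stand-in for a source that cannot repartition itself.
    needs_partitions = True

    @staticmethod
    def get_n(data):
        return len(data)


class RayLikeSource:
    # Hypothetical stand-in for a Ray dataset source, which can repartition.
    needs_partitions = False

    @staticmethod
    def get_n(data):
        return len(data)


def assert_enough_shards(data_source, data, num_actors):
    """Mirror of the guard: only error out if the source needs partitions."""
    max_num_shards = data_source.get_n(data)
    if num_actors > max_num_shards and data_source.needs_partitions:
        raise RuntimeError(
            f"Trying to shard data for {num_actors} actors, but the "
            f"maximum number of shards is {max_num_shards}.")


rows = [1, 2, 3]
# Passes: a Ray-like source can be repartitioned to 8 blocks later.
assert_enough_shards(RayLikeSource, rows, num_actors=8)
# Raises: a pandas-like source has no way to create more shards.
try:
    assert_enough_shards(PandasLikeSource, rows, num_actors=8)
except RuntimeError as exc:
    print(exc)
```

This is why `needs_partitions` defaults to `True` on the base `DataSource` and is only set to `False` for `RayDataset`, which can call repartitioning itself.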
