Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit ece58d5

Browse files
authored
Better algorithm for batch sharding indices (#132)
* Better algorithm for batch sharding indices * Nit * Quick test * Lint * Lint
1 parent 7cbb134 commit ece58d5

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

xgboost_ray/matrix.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import glob
2-
import math
32
import uuid
43
from enum import Enum
54
from typing import Union, Optional, Tuple, Iterable, List, Dict, Sequence, \
@@ -885,10 +884,12 @@ def _get_sharding_indices(sharding: RayShardingMode, rank: int,
885884
num_actors: int, n: int):
886885
"""Return indices that belong to worker with rank `rank`"""
887886
if sharding == RayShardingMode.BATCH:
888-
start_index = int(rank * math.ceil(n / num_actors))
889-
end_index = int((rank + 1) * math.ceil(n / num_actors))
890-
end_index = min(end_index, n)
891-
indices = list(range(start_index, end_index))
887+
# based on numpy.array_split
888+
# github.com/numpy/numpy/blob/v1.21.0/numpy/lib/shape_base.py
889+
n_per_actor, extras = divmod(n, num_actors)
890+
div_points = np.array([0] + extras * [n_per_actor + 1] +
891+
(num_actors - extras) * [n_per_actor]).cumsum()
892+
indices = list(range(div_points[rank], div_points[rank + 1]))
892893
elif sharding == RayShardingMode.INTERLEAVED:
893894
indices = list(range(rank, n, num_actors))
894895
else:

xgboost_ray/tests/test_matrix.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import ray
99

1010
from xgboost_ray import RayDMatrix
11-
from xgboost_ray.matrix import concat_dataframes, RayShardingMode
11+
from xgboost_ray.matrix import (concat_dataframes, RayShardingMode,
12+
_get_sharding_indices)
1213

1314

1415
class XGBoostRayDMatrixTest(unittest.TestCase):
@@ -351,6 +352,12 @@ def testTooManyActorsCentral(self):
351352
with self.assertRaises(RuntimeError):
352353
RayDMatrix(data_df, num_actors=34, distributed=False)
353354

355+
def testBatchShardingAllActorsGetIndices(self):
356+
"""Check if all actors get indices with batch mode"""
357+
for i in range(16):
358+
self.assertTrue(
359+
_get_sharding_indices(RayShardingMode.BATCH, i, 16, 100))
360+
354361

355362
if __name__ == "__main__":
356363
import pytest

0 commit comments

Comments (0)