Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit e68ff63

Browse files
authored
Fix BATCH sharding (#123)
1 parent 57749f9 commit e68ff63

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

xgboost_ray/data_sources/petastorm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def load_data(data: Union[str, Sequence[str]],
6565
**kwargs) -> pd.DataFrame:
6666
_assert_petastorm_installed()
6767
with petastorm.make_batch_reader(data) as reader:
68-
shards = [pd.DataFrame(batch._asdict()) for batch in reader]
68+
shards = [
69+
pd.DataFrame(batch._asdict()) for i, batch in enumerate(reader)
70+
if not indices or i in indices
71+
]
6972

7073
local_df = pd.concat(shards, copy=False)
7174

xgboost_ray/matrix.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -885,8 +885,9 @@ def _get_sharding_indices(sharding: RayShardingMode, rank: int,
885885
num_actors: int, n: int):
886886
"""Return indices that belong to worker with rank `rank`"""
887887
if sharding == RayShardingMode.BATCH:
888-
start_index = int(math.floor(rank / num_actors) * n)
889-
end_index = int(math.floor(rank + 1 / num_actors) * n)
888+
start_index = int(rank * math.ceil(n / num_actors))
889+
end_index = int((rank + 1) * math.ceil(n / num_actors))
890+
end_index = min(end_index, n)
890891
indices = list(range(start_index, end_index))
891892
elif sharding == RayShardingMode.INTERLEAVED:
892893
indices = list(range(rank, n, num_actors))
@@ -913,7 +914,7 @@ def combine_data(sharding: RayShardingMode, data: Iterable) -> np.ndarray:
913914
if data[0].ndim == 1:
914915
# most common case
915916
if sharding == RayShardingMode.BATCH:
916-
res = np.ravel(data)
917+
res = np.concatenate(data)
917918
elif sharding == RayShardingMode.INTERLEAVED:
918919
# Sometimes the lengths are off by 1 for uneven divisions
919920
min_len = min(len(d) for d in data)

xgboost_ray/tests/test_end_to_end.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import ray
1010
from ray.exceptions import RayActorError, RayTaskError
1111

12-
from xgboost_ray import RayParams, train, RayDMatrix, predict
12+
from xgboost_ray import RayParams, train, RayDMatrix, predict, RayShardingMode
1313
from xgboost_ray.main import RayXGBoostTrainingError
1414
from xgboost_ray.callback import DistributedCallback
1515
from xgboost_ray.tests.utils import get_num_trees
@@ -134,20 +134,58 @@ def testHalfTraining(self):
134134
pred_test = bst.predict(test_X)
135135
self.assertSequenceEqual(test_y_second, list(pred_test))
136136

137-
def testJointTraining(self):
137+
def _testJointTraining(self,
138+
sharding=RayShardingMode.INTERLEAVED,
139+
softprob=False):
138140
"""Train with Ray. The data will be split, but the trees
139141
should be combined together and find the true model."""
140-
ray.init(num_cpus=2, num_gpus=0)
142+
params = self.params.copy()
143+
if softprob:
144+
params["objective"] = "multi:softprob"
141145

142146
bst = train(
143-
self.params,
144-
RayDMatrix(self.x, self.y),
147+
params,
148+
RayDMatrix(self.x, self.y, sharding=sharding),
145149
ray_params=RayParams(num_actors=2))
146150

147151
x_mat = xgb.DMatrix(self.x)
148152
pred_y = bst.predict(x_mat)
153+
if softprob:
154+
pred_y = np.argmax(pred_y, axis=1)
155+
pred_y = pred_y.astype(int)
149156
self.assertSequenceEqual(list(self.y), list(pred_y))
150157

158+
x_mat = RayDMatrix(self.x, sharding=sharding)
159+
pred_y = predict(bst, x_mat, ray_params=RayParams(num_actors=2))
160+
if softprob:
161+
pred_y = np.argmax(pred_y, axis=1)
162+
pred_y = pred_y.astype(int)
163+
self.assertSequenceEqual(list(self.y), list(pred_y))
164+
165+
# try on an odd number of rows
166+
bst = train(
167+
params,
168+
RayDMatrix(self.x[:-1], self.y[:-1], sharding=sharding),
169+
ray_params=RayParams(num_actors=2))
170+
171+
x_mat = RayDMatrix(self.x[:-1], sharding=sharding)
172+
pred_y = predict(bst, x_mat, ray_params=RayParams(num_actors=2))
173+
if softprob:
174+
pred_y = np.argmax(pred_y, axis=1)
175+
pred_y = pred_y.astype(int)
176+
self.assertSequenceEqual(list(self.y[:-1]), list(pred_y))
177+
178+
def testJointTrainingInterleaved(self):
179+
ray.init(num_cpus=2, num_gpus=0)
180+
self._testJointTraining(sharding=RayShardingMode.INTERLEAVED)
181+
self._testJointTraining(
182+
sharding=RayShardingMode.INTERLEAVED, softprob=True)
183+
184+
def testJointTrainingBatch(self):
185+
ray.init(num_cpus=2, num_gpus=0)
186+
self._testJointTraining(sharding=RayShardingMode.BATCH)
187+
self._testJointTraining(sharding=RayShardingMode.BATCH, softprob=True)
188+
151189
def testTrainPredict(self,
152190
init=True,
153191
remote=None,

0 commit comments

Comments (0)