Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.

Commit 909848a

Browse files
authored
Update Ray Dataset with prediction example (#287)
Prediction with Ray Data doesn't work with distributed loading (see https://discuss.ray.io/t/raytaskerror-typeerror/11486/2). This PR adds a simple example on how to do batch inference with Ray Data. Ideally we can convert this automatically to Ray Data-based batch inference in predict() in the short term. In the long term, we should discontinue all non-ray-data APIs and move data source support to ray data instead (except for petastorm, most of them should be supported anyways).
1 parent 6c038a2 commit 909848a

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

xgboost_ray/examples/simple_ray_dataset.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import numpy as np
44
import pandas as pd
55
import ray
6+
from xgboost import DMatrix
67

78
from xgboost_ray import RayDMatrix, RayParams, train
89

910

1011
def main(cpus_per_actor, num_actors):
12+
np.random.seed(1234)
1113
# Generate dataset
1214
x = np.repeat(range(8), 16).reshape((32, 4))
1315
# Even numbers --> 0, odd numbers --> 1
@@ -22,16 +24,7 @@ def main(cpus_per_actor, num_actors):
2224
data.columns = [str(c) for c in data.columns]
2325
data["label"] = y
2426

25-
# There was recent API change - the first clause covers the new
26-
# and current Ray master API
27-
if hasattr(ray.data, "from_pandas_refs"):
28-
# Generate Ray dataset from 4 partitions
29-
ray_ds = ray.data.from_pandas(data).repartition(num_actors)
30-
else:
31-
# Split into 4 partitions
32-
partitions = [ray.put(part) for part in np.split(data, num_actors)]
33-
ray_ds = ray.data.from_pandas(partitions)
34-
27+
ray_ds = ray.data.from_pandas(data)
3528
train_set = RayDMatrix(ray_ds, "label")
3629

3730
evals_result = {}
@@ -62,6 +55,12 @@ def main(cpus_per_actor, num_actors):
6255
bst.save_model(model_path)
6356
print("Final training error: {:.4f}".format(evals_result["train"]["error"][-1]))
6457

58+
# Distributed prediction
59+
scored = ray_ds.drop_columns(["label"]).map_batches(
60+
lambda batch: {"pred": bst.predict(DMatrix(batch))}, batch_format="pandas"
61+
)
62+
print(scored.to_pandas())
63+
6564

6665
if __name__ == "__main__":
6766
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)