Skip to content

Commit b885026

Browse files
author
Zhi Lin
authored
Add an estimator API for XGBoost using Ray's XGBoostTrainer (#289)
* init Signed-off-by: Zhi Lin <zhi.lin@intel.com> * update test Signed-off-by: Zhi Lin <zhi.lin@intel.com> * fix Signed-off-by: Zhi Lin <zhi.lin@intel.com> * test why it fails on mac py3.9 Signed-off-by: Zhi Lin <zhi.lin@intel.com> * debug use_fs Signed-off-by: Zhi Lin <zhi.lin@intel.com> * test Signed-off-by: Zhi Lin <zhi.lin@intel.com> * test again Signed-off-by: Zhi Lin <zhi.lin@intel.com> * verbose test Signed-off-by: Zhi Lin <zhi.lin@intel.com> * delete cat for it wont get run Signed-off-by: Zhi Lin <zhi.lin@intel.com> * skip mpi test on macos Signed-off-by: Zhi Lin <zhi.lin@intel.com> * skip xgb tests on mac Signed-off-by: Zhi Lin <zhi.lin@intel.com> Signed-off-by: Zhi Lin <zhi.lin@intel.com>
1 parent 1315d93 commit b885026

File tree

5 files changed

+207
-3
lines changed

5 files changed

+207
-3
lines changed

.github/workflows/raydp.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ jobs:
102102
- name: Test with pytest
103103
run: |
104104
ray start --head --num-cpus 6
105-
pytest python/raydp/tests/ -m"not error_on_custom_resource"
106-
pytest python/raydp/tests/ -m"error_on_custom_resource"
105+
pytest python/raydp/tests/ -v -m"not error_on_custom_resource"
106+
pytest python/raydp/tests/ -v -m"error_on_custom_resource"
107107
ray stop --force
108108
- name: Test Examples
109109
run: |

python/raydp/tests/test_mpi.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
import sys
19-
19+
import platform
2020
import pytest
2121
import ray
2222
import ray._private.services
@@ -27,6 +27,8 @@
2727

2828
@pytest.mark.timeout(10)
2929
def test_mpi_start(ray_cluster):
30+
if platform.system() == "Darwin":
31+
pytest.skip("Skip MPI test on MacOS")
3032
if not ray.worker.global_worker.connected:
3133
pytest.skip("Skip MPI test if using ray client")
3234
job = create_mpi_job(job_name="test",
@@ -58,6 +60,8 @@ def func(context: WorkerContext):
5860

5961
@pytest.mark.timeout(10)
6062
def test_mpi_get_rank_address(ray_cluster):
63+
if platform.system() == "Darwin":
64+
pytest.skip("Skip MPI test on MacOS")
6165
if not ray.worker.global_worker.connected:
6266
pytest.skip("Skip MPI test if using ray client")
6367
with create_mpi_job(job_name="test",
@@ -74,6 +78,8 @@ def test_mpi_get_rank_address(ray_cluster):
7478

7579

7680
def test_mpi_with_script_prepare_fn(ray_cluster):
81+
if platform.system() == "Darwin":
82+
pytest.skip("Skip MPI test on MacOS")
7783
if not ray.worker.global_worker.connected:
7884
pytest.skip("Skip MPI test if using ray client")
7985
def script_prepare_fn(context: MPIJobContext):
@@ -99,6 +105,8 @@ def f(context: WorkerContext):
99105

100106

101107
def test_mpi_with_pg(ray_cluster):
108+
if platform.system() == "Darwin":
109+
pytest.skip("Skip MPI test on MacOS")
102110
if not ray.worker.global_worker.connected:
103111
pytest.skip("Skip MPI test if using ray client")
104112
pg = placement_group(bundles=[{"CPU": 2}], strategy="STRICT_SPREAD")

python/raydp/tests/test_xgboost.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import os
19+
import sys
20+
import shutil
21+
import platform
22+
import pytest
23+
import pyspark
24+
import numpy as np
25+
from pyspark.sql.functions import rand
26+
27+
from raydp.xgboost import XGBoostEstimator
28+
from raydp.utils import random_split
29+
30+
@pytest.mark.parametrize("use_fs_directory", [True, False])
def test_xgb_estimator(spark_on_ray_small, use_fs_directory):
    """Smoke-test XGBoostEstimator.fit_on_spark on a synthetic linear dataset.

    Runs once exchanging data through a parquet directory on the local
    filesystem (use_fs_directory=True) and once converting in memory.
    """
    if platform.system() == "Darwin":
        # XGBoost-on-Ray tests are skipped on macOS CI (see commit message:
        # they fail on the mac py3.9 runner).
        pytest.skip("Skip XGBoost test on MacOS")
    spark = spark_on_ray_small

    # Build a dataset satisfying z = 3 * x + 4 * y + noise + 5.
    df: pyspark.sql.DataFrame = spark.range(0, 100000)
    df = df.withColumn("x", rand() * 100)  # add x column
    df = df.withColumn("y", rand() * 1000)  # add y column
    df = df.withColumn("z", df.x * 3 + df.y * 4 + rand() + 5)  # add z column
    df = df.select(df.x, df.y, df.z)

    train_df, test_df = random_split(df, [0.7, 0.3])
    params = {}
    estimator = XGBoostEstimator(params, "z", resources_per_worker={"CPU": 1})
    if use_fs_directory:
        # `directory` (not `dir`) to avoid shadowing the builtin.
        directory = os.path.dirname(os.path.realpath(__file__)) + "/test_xgboost"
        uri = "file://" + directory
        estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
    else:
        estimator.fit_on_spark(train_df, test_df)
    # Sanity check: the trained booster can predict on a 2-feature row (x, y).
    print(estimator.get_model().inplace_predict(np.asarray([[1, 2]])))
    if use_fs_directory:
        shutil.rmtree(directory)
55+
56+
if __name__ == '__main__':
    # Allow running this test standalone against an already-running Ray
    # cluster, outside of pytest.
    import ray
    import raydp
    ray.init(address="auto")
    spark = raydp.init_spark('test_xgboost', 1, 1, '500m')
    test_xgb_estimator(spark, True)

python/raydp/xgboost/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from .estimator import XGBoostEstimator
19+
20+
__all__ = ["XGBoostEstimator"]

python/raydp/xgboost/estimator.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Any, Callable, List, NoReturn, Optional, Union, Dict
19+
20+
from raydp.estimator import EstimatorInterface
21+
from raydp.spark.interfaces import SparkEstimatorInterface, DF, OPTIONAL_DF
22+
from raydp import stop_spark
23+
from raydp.spark import spark_dataframe_to_ray_dataset
24+
25+
import ray
26+
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
27+
from ray.data.dataset import Dataset
28+
from ray.train.xgboost import XGBoostTrainer, XGBoostCheckpoint
29+
30+
class XGBoostEstimator(EstimatorInterface, SparkEstimatorInterface):
    """Distributed XGBoost estimator built on Ray Train's ``XGBoostTrainer``.

    Accepts Ray Datasets directly (``fit``) or Spark DataFrames
    (``fit_on_spark``, which converts them to Ray Datasets first). The
    trained model is retrieved with ``get_model`` after fitting.
    """

    def __init__(self,
                 xgboost_params: Dict,
                 label_column: str,
                 dmatrix_params: Optional[Dict] = None,
                 num_workers: int = 1,
                 resources_per_worker: Optional[Dict[str, float]] = None,
                 shuffle: bool = True):
        """
        :param xgboost_params: XGBoost training parameters.
            Refer to `XGBoost documentation <https://xgboost.readthedocs.io/>`_
            for a list of possible parameters.
        :param label_column: Name of the label column. A column with this name
            must be present in the training dataset passed to fit() later.
        :param dmatrix_params: Dict of ``dataset name:dict of kwargs`` passed to respective
            :class:`xgboost_ray.RayDMatrix` initializations, which in turn are passed
            to ``xgboost.DMatrix`` objects created on each worker. For example, this can
            be used to add sample weights with the ``weights`` parameter.
        :param num_workers: the number of workers to do the distributed training.
        :param resources_per_worker: the resources defined in this Dict will be reserved for
            each worker. The ``CPU`` and ``GPU`` keys (case-sensitive) can be defined to
            override the number of CPU/GPUs used by each worker.
        :param shuffle: whether to shuffle the data
        """
        self._xgboost_params = xgboost_params
        self._label_column = label_column
        self._dmatrix_params = dmatrix_params
        self._num_workers = num_workers
        self._resources_per_worker = resources_per_worker
        self._shuffle = shuffle

    def fit(self,
            train_ds: Dataset,
            evaluate_ds: Optional[Dataset] = None,
            max_retries: int = 3) -> None:
        """Train on Ray Datasets; the result is stored on the estimator.

        Annotated ``-> None``: the previous ``NoReturn`` annotation was
        incorrect — ``typing.NoReturn`` means the function never returns
        normally, but this method returns after training completes.

        :param train_ds: the training Ray Dataset.
        :param evaluate_ds: optional evaluation Ray Dataset.
        :param max_retries: max training failures tolerated before giving up.
        """
        scaling_config = ScalingConfig(num_workers=self._num_workers,
                                       resources_per_worker=self._resources_per_worker)
        run_config = RunConfig(failure_config=FailureConfig(max_failures=max_retries))
        if self._shuffle:
            train_ds = train_ds.random_shuffle()
            if evaluate_ds:
                evaluate_ds = evaluate_ds.random_shuffle()
        datasets = {"train": train_ds}
        if evaluate_ds:
            datasets["evaluate"] = evaluate_ds
        trainer = XGBoostTrainer(scaling_config=scaling_config,
                                 datasets=datasets,
                                 label_column=self._label_column,
                                 params=self._xgboost_params,
                                 dmatrix_params=self._dmatrix_params,
                                 run_config=run_config)
        # Keep the Result so get_model() can load the best checkpoint later.
        self._results = trainer.fit()

    def fit_on_spark(self,
                     train_df: DF,
                     evaluate_df: OPTIONAL_DF = None,
                     max_retries: int = 3,
                     fs_directory: Optional[str] = None,
                     compression: Optional[str] = None,
                     stop_spark_after_conversion: bool = False) -> None:
        """Convert Spark DataFrames to Ray Datasets, then delegate to ``fit``.

        :param train_df: the training Spark DataFrame.
        :param evaluate_df: optional evaluation Spark DataFrame.
        :param max_retries: forwarded to ``fit``.
        :param fs_directory: if given, data is exchanged by writing parquet
            files under ``fs_directory/<spark application id>`` and reading
            them back with ``ray.data.read_parquet``; otherwise DataFrames are
            converted in memory via ``spark_dataframe_to_ray_dataset``.
        :param compression: parquet compression codec, only used when
            ``fs_directory`` is set.
        :param stop_spark_after_conversion: stop Spark before training (data
            ownership is transferred so the converted datasets survive),
            freeing Spark's resources for the Ray trainer.
        """
        train_df = self._check_and_convert(train_df)
        evaluate_ds = None
        if fs_directory is not None:
            # Namespace the parquet files by application id so concurrent
            # Spark apps sharing fs_directory don't collide.
            app_id = train_df.sql_ctx.sparkSession.sparkContext.applicationId
            path = fs_directory.rstrip("/") + f"/{app_id}"
            train_df.write.parquet(path + "/train", compression=compression)
            train_ds = ray.data.read_parquet(path + "/train")
            if evaluate_df is not None:
                evaluate_df = self._check_and_convert(evaluate_df)
                evaluate_df.write.parquet(path + "/test", compression=compression)
                evaluate_ds = ray.data.read_parquet(path + "/test")
        else:
            train_ds = spark_dataframe_to_ray_dataset(train_df,
                                                      parallelism=self._num_workers,
                                                      _use_owner=stop_spark_after_conversion)
            if evaluate_df is not None:
                evaluate_df = self._check_and_convert(evaluate_df)
                evaluate_ds = spark_dataframe_to_ray_dataset(evaluate_df,
                                                             parallelism=self._num_workers,
                                                             _use_owner=stop_spark_after_conversion)
        if stop_spark_after_conversion:
            # Keep converted data alive while shutting Spark down.
            stop_spark(cleanup_data=False)
        return self.fit(
            train_ds, evaluate_ds, max_retries)

    def get_model(self):
        """Return the trained ``xgboost.Booster`` from the fit checkpoint.

        Must be called after ``fit``/``fit_on_spark``; raises
        ``AttributeError`` otherwise (``self._results`` is only set by fit).
        """
        return XGBoostCheckpoint.from_checkpoint(self._results.checkpoint).get_model()

0 commit comments

Comments
 (0)