
Commit 55ad540

Fix num_threads param, check for n_jobs param (#118)
1 parent 2501b4c commit 55ad540

File tree

.github/workflows/test.yaml
xgboost_ray/main.py

2 files changed (+11 -8 lines)

.github/workflows/test.yaml

Lines changed: 5 additions & 5 deletions
@@ -30,7 +30,7 @@ jobs:
 
   test_linux_ray_master:
     runs-on: ubuntu-latest
-    timeout-minutes: 30
+    timeout-minutes: 40
     strategy:
       matrix:
         python-version: [3.6.9, 3.7, 3.8]
@@ -68,7 +68,7 @@ jobs:
 
   test_linux_ray_release:
     runs-on: ubuntu-latest
-    timeout-minutes: 30
+    timeout-minutes: 40
     strategy:
       matrix:
         python-version: [3.6.9, 3.7, 3.8]
@@ -101,7 +101,7 @@ jobs:
     # Test compatibility when some optional libraries are missing
     # Test runs on latest ray release
     runs-on: ubuntu-latest
-    timeout-minutes: 30
+    timeout-minutes: 40
     strategy:
       matrix:
         python-version: [3.6.9, 3.7, 3.8]
@@ -138,7 +138,7 @@ jobs:
   test_linux_cutting_edge:
     # Tests on cutting edge, i.e. latest Ray master, latest XGBoost master
     runs-on: ubuntu-latest
-    timeout-minutes: 30
+    timeout-minutes: 40
     strategy:
       matrix:
         python-version: [3.6.9, 3.7, 3.8]
@@ -194,7 +194,7 @@ jobs:
   test_linux_xgboost_legacy:
     # Tests on XGBoost 0.90 and latest Ray release
     runs-on: ubuntu-latest
-    timeout-minutes: 28
+    timeout-minutes: 38
     strategy:
       matrix:
         python-version: [3.6.9]

xgboost_ray/main.py

Lines changed: 6 additions & 3 deletions
@@ -501,7 +501,8 @@ def train(self, rabit_args: List[str], return_bst: bool,
 
         if "nthread" not in local_params and "n_jobs" not in local_params:
            if num_threads > 0:
-                local_params["num_threads"] = num_threads
+                local_params["nthread"] = num_threads
+                local_params["n_jobs"] = num_threads
             else:
                 local_params["nthread"] = sum(
                     num
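
The hunk above replaces the invalid "num_threads" key with XGBoost's two accepted thread aliases. As a minimal standalone sketch of that defaulting logic (not the library code): set_thread_params and fallback_cpus are hypothetical stand-ins, and the real implementation derives the fallback thread count from the actor's assigned Ray CPU resources rather than from an argument.

from typing import Dict


def set_thread_params(local_params: Dict, num_threads: int,
                      fallback_cpus: int) -> Dict:
    # Fill in both XGBoost thread aliases only if the user set neither.
    if "nthread" not in local_params and "n_jobs" not in local_params:
        if num_threads > 0:
            # Set both aliases so either spelling is honored downstream.
            local_params["nthread"] = num_threads
            local_params["n_jobs"] = num_threads
        else:
            # Hypothetical fallback; the real code sums the actor's
            # assigned CPU resources instead of taking a plain argument.
            local_params["nthread"] = fallback_cpus
    return local_params


print(set_thread_params({}, num_threads=4, fallback_cpus=2))
# -> {'nthread': 4, 'n_jobs': 4}
print(set_thread_params({"n_jobs": 8}, num_threads=4, fallback_cpus=2))
# -> {'n_jobs': 8}  (a user-supplied value is left untouched)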
@@ -838,14 +839,16 @@ def _train(params: Dict,
     # Un-schedule possible scheduled restarts
     _training_state.restart_training_at = None
 
-    if "nthread" in params:
-        if params["nthread"] > cpus_per_actor:
+    if "nthread" in params or "n_jobs" in params:
+        if ("nthread" in params and params["nthread"] > cpus_per_actor) or (
+                "n_jobs" in params and params["n_jobs"] > cpus_per_actor):
             raise ValueError(
                 "Specified number of threads greater than number of CPUs. "
                 "\nFIX THIS by passing a lower value for the `nthread` "
                 "parameter or a higher number for `cpus_per_actor`.")
     else:
         params["nthread"] = cpus_per_actor
+        params["n_jobs"] = cpus_per_actor
 
     # This is a callback that handles actor failures.
     # We identify the rank of the failed actor, add this to a set of
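
To make the new validation concrete, here is a self-contained sketch of the check above. check_thread_params is a hypothetical helper used only for illustration; in the repository the logic sits inline in _train and operates on the user-supplied XGBoost params.

from typing import Dict


def check_thread_params(params: Dict, cpus_per_actor: int) -> Dict:
    # Validate either thread alias against the per-actor CPU budget.
    if "nthread" in params or "n_jobs" in params:
        if ("nthread" in params and params["nthread"] > cpus_per_actor) or (
                "n_jobs" in params and params["n_jobs"] > cpus_per_actor):
            raise ValueError(
                "Specified number of threads greater than number of CPUs.")
    else:
        # Neither alias given: default both to the per-actor CPU count.
        params["nthread"] = cpus_per_actor
        params["n_jobs"] = cpus_per_actor
    return params


print(check_thread_params({}, cpus_per_actor=4))
# -> {'nthread': 4, 'n_jobs': 4}
print(check_thread_params({"n_jobs": 2}, cpus_per_actor=4))
# -> {'n_jobs': 2}
try:
    check_thread_params({"n_jobs": 16}, cpus_per_actor=4)
except ValueError as exc:
    print("rejected:", exc)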

0 commit comments
