Skip to content

Commit 1bf3899

Browse files
trivialfis, hcho3
authored and committed
Fix dask ip resolution. (#6475)
This adopts the solution used in dask/dask-xgboost#40 which employs the get_host_ip from dmlc-core tracker.
1 parent c39f6b2 commit 1bf3899

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

python-package/xgboost/dask.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
3434
from .core import _deprecate_positional_args
3535
from .training import train as worker_train
36-
from .tracker import RabitTracker
36+
from .tracker import RabitTracker, get_host_ip
3737
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
3838
from .sklearn import xgboost_model_doc
3939

@@ -70,8 +70,7 @@
7070
def _start_tracker(n_workers):
7171
"""Start Rabit tracker """
7272
env = {'DMLC_NUM_WORKER': n_workers}
73-
import socket
74-
host = socket.gethostbyname(socket.gethostname())
73+
host = get_host_ip('auto')
7574
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
7675
env.update(rabit_context.slave_envs())
7776

python-package/xgboost/tracker.py

+22
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ def get_some_ip(host):
5252
return socket.getaddrinfo(host, None)[0][4][0]
5353

5454

55+
def get_host_ip(hostIP=None):
    """Resolve the address that should be advertised for this host.

    Parameters
    ----------
    hostIP : str, optional
        ``None`` and ``'auto'`` are treated as ``'ip'``.  ``'dns'`` selects the
        fully qualified domain name, ``'ip'`` resolves the host name to an
        address.  Any other value is returned unchanged.

    Returns
    -------
    str
        The host name or address selected according to ``hostIP``.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logging.warning(
                'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
            hostIP = socket.gethostbyname(socket.gethostname())
        if hostIP.startswith("127."):
            # The name resolved to loopback; connect a datagram socket and
            # read back the local address the OS picked for it.  The socket is
            # closed by the ``with`` block so the descriptor is not leaked.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                # doesn't have to be reachable
                probe.connect(('10.255.255.255', 1))
                hostIP = probe.getsockname()[0]
    return hostIP
75+
76+
5577
def get_family(addr):
    """Return the address family of the first ``getaddrinfo`` entry for addr."""
    first_entry = socket.getaddrinfo(addr, None)[0]
    family = first_entry[0]
    return family
5779

0 commit comments

Comments
 (0)