Commit db15ef5

noabauma and jeffra authored
deepspeed.init_distributed() support for TCP protocols (#2905)
Co-authored-by: Jeff Rasley <[email protected]>
1 parent 0acf7e9 commit db15ef5

2 files changed, +18 −6 lines changed


deepspeed/comm/comm.py (+6 −2)
@@ -603,7 +603,9 @@ def init_distributed(dist_backend=None,
                      timeout=default_pg_timeout,
                      init_method=None,
                      dist_init_required=None,
-                     config=None):
+                     config=None,
+                     rank=-1,
+                     world_size=-1):
     ''' Initialize dist backend, potentially performing MPI discovery if needed

     Arguments:
@@ -614,6 +616,8 @@ def init_distributed(dist_backend=None,
         timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
         init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
         config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
+        rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
+        world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
     '''
     global cdb

@@ -658,7 +662,7 @@ def init_distributed(dist_backend=None,
         'Initializing TorchBackend in DeepSpeed with backend {}'.format(
             dist_backend))
     # Create a torch backend object, initialize torch distributed, and assign to cdb
-    cdb = TorchBackend(dist_backend, timeout, init_method)
+    cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)


 def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
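
With this change a caller can use a TCP rendezvous by passing init_method together with explicit rank and world_size. A minimal sketch of such a call follows; the master address, port, and the environment variables used to derive rank/world_size are illustrative assumptions, not part of this commit:

    # Sketch: explicit TCP initialization instead of the default env:// method.
    # The address/port and the RANK/WORLD_SIZE environment variables are placeholders;
    # in practice these values come from your launcher or job scheduler.
    import os
    import deepspeed

    rank = int(os.environ.get("RANK", "0"))               # this process's global rank
    world_size = int(os.environ.get("WORLD_SIZE", "1"))   # total number of processes

    deepspeed.init_distributed(dist_backend="nccl",
                               init_method="tcp://10.1.1.20:29500",  # placeholder master addr:port
                               rank=rank,
                               world_size=world_size)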

deepspeed/comm/torch.py (+12 −4)
@@ -16,7 +16,13 @@ class TorchBackend(Backend):
     so no need to wrap all the functions. We can keep adding wrappers as
     needed.
     """
-    def __init__(self, backend, timeout, init_method, name='torch'):
+    def __init__(self,
+                 backend,
+                 timeout,
+                 init_method,
+                 rank=-1,
+                 world_size=-1,
+                 name='torch'):
         super(TorchBackend, self).__init__()
         self.torch_version_before_18 = older_torch()
         self.has_allgather_base = has_allgather_base()
@@ -27,13 +33,15 @@ def __init__(self, backend, timeout, init_method, name='torch'):
         # The idea is to fake that dist backend is initialized even when
         # it is not so we can run on a single GPU without doing any init_process_group
         self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method)
+        self.init_process_group(backend, timeout, init_method, rank, world_size)

-    def init_process_group(self, backend, timeout, init_method):
+    def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(backend,
                                                  timeout=timeout,
-                                                 init_method=init_method)
+                                                 init_method=init_method,
+                                                 rank=rank,
+                                                 world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

     def all_reduce(self,
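
The new defaults (rank=-1, world_size=-1) match the defaults of torch.distributed.init_process_group itself, so callers on the env:// or MPI paths see no behavior change; only explicitly supplied values take effect. For reference, a sketch of the underlying PyTorch call that TorchBackend now forwards to (the gloo backend, loopback address, and single-process values are assumptions chosen so the snippet runs standalone):

    # Sketch of the forwarded call; gloo backend and a single local process are
    # assumed so this runs without GPUs or a multi-node setup.
    from datetime import timedelta
    import torch.distributed as dist

    dist.init_process_group("gloo",
                            timeout=timedelta(minutes=30),
                            init_method="tcp://127.0.0.1:29500",  # placeholder rendezvous address
                            rank=0,
                            world_size=1)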

0 commit comments
