Skip to content

Commit b57429e

Browse files
szmigaczko3n1g
authored andcommitted
ADLR/megatron-lm!1757 - Parse LOCAL_RANK in arguments.py, get device from LOCAL_RANK, and set device_id for init_process_group
1 parent 16eea87 commit b57429e

2 files changed

Lines changed: 18 additions & 15 deletions

File tree

megatron/training/arguments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1459,7 +1459,7 @@ def _add_distributed_args(parser):
14591459
default=False, help='If set, use custom-built ring exchange '
14601460
'for p2p communications. Note that this option will require '
14611461
'a custom built image that support ring-exchange p2p.')
1462-
group.add_argument('--local_rank', type=int, default=None,
1462+
group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
14631463
help='local rank passed from distributed launcher.')
14641464
group.add_argument('--lazy-mpu-init', type=bool, required=False,
14651465
help='If set to True, initialize_megatron() '

megatron/training/initialize.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
import random
66
import os
7+
import packaging
8+
import packaging.version
79
import time
810

911
import numpy as np
@@ -233,21 +235,22 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
233235
print("> initializing torch distributed ...", flush=True)
234236
# Manually set the device ids.
235237
if device_count > 0:
236-
device = args.rank % device_count
237-
if args.local_rank is not None:
238-
assert (
239-
args.local_rank == device
240-
), "expected local-rank to be the same as rank % device-count."
241-
else:
242-
args.local_rank = device
243-
torch.cuda.set_device(device)
238+
torch.cuda.set_device(args.local_rank)
239+
device_id = torch.device(f'cuda:{args.local_rank}')
240+
else:
241+
device_id = None
242+
244243
# Call the init process
245-
torch.distributed.init_process_group(
246-
backend=args.distributed_backend,
247-
world_size=args.world_size,
248-
rank=args.rank,
249-
timeout=timedelta(minutes=args.distributed_timeout_minutes),
250-
)
244+
init_process_group_kwargs = {
245+
'backend' : args.distributed_backend,
246+
'world_size': args.world_size,
247+
'rank': args.rank,
248+
'timeout': timedelta(minutes=args.distributed_timeout_minutes),
249+
}
250+
if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
251+
init_process_group_kwargs['device_id'] = device_id
252+
253+
torch.distributed.init_process_group(**init_process_group_kwargs)
251254

252255
# Set the tensor model-parallel, pipeline model-parallel, and
253256
# data-parallel communicators.

0 commit comments

Comments
 (0)