Skip to content

Commit 1efefa7

Browse files
committed
Merge branch 'use_local_rank' into 'main'
Parse LOCAL_RANK in arguments.py, get device from LOCAL_RANK, and set device_id for init_process_group See merge request ADLR/megatron-lm!1757
2 parents 9768756 + b57429e commit 1efefa7

2 files changed

Lines changed: 18 additions & 15 deletions

File tree

megatron/training/arguments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,7 @@ def _add_distributed_args(parser):
14631463
default=False, help='If set, use custom-built ring exchange '
14641464
'for p2p communications. Note that this option will require '
14651465
'a custom built image that support ring-exchange p2p.')
1466-
group.add_argument('--local_rank', type=int, default=None,
1466+
group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
14671467
help='local rank passed from distributed launcher.')
14681468
group.add_argument('--lazy-mpu-init', type=bool, required=False,
14691469
help='If set to True, initialize_megatron() '

megatron/training/initialize.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
import random
66
import os
7+
import packaging
8+
import packaging.version
79
import time
810

911
import numpy as np
@@ -233,21 +235,22 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
233235
print("> initializing torch distributed ...", flush=True)
234236
# Manually set the device ids.
235237
if device_count > 0:
236-
device = args.rank % device_count
237-
if args.local_rank is not None:
238-
assert (
239-
args.local_rank == device
240-
), "expected local-rank to be the same as rank % device-count."
241-
else:
242-
args.local_rank = device
243-
torch.cuda.set_device(device)
238+
torch.cuda.set_device(args.local_rank)
239+
device_id = torch.device(f'cuda:{args.local_rank}')
240+
else:
241+
device_id = None
242+
244243
# Call the init process
245-
torch.distributed.init_process_group(
246-
backend=args.distributed_backend,
247-
world_size=args.world_size,
248-
rank=args.rank,
249-
timeout=timedelta(minutes=args.distributed_timeout_minutes),
250-
)
244+
init_process_group_kwargs = {
245+
'backend' : args.distributed_backend,
246+
'world_size': args.world_size,
247+
'rank': args.rank,
248+
'timeout': timedelta(minutes=args.distributed_timeout_minutes),
249+
}
250+
if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
251+
init_process_group_kwargs['device_id'] = device_id
252+
253+
torch.distributed.init_process_group(**init_process_group_kwargs)
251254

252255
# Set the tensor model-parallel, pipeline model-parallel, and
253256
# data-parallel communicators.

0 commit comments

Comments
 (0)