|
4 | 4 | import logging |
5 | 5 | import random |
6 | 6 | import os |
| 7 | +import packaging |
| 8 | +import packaging.version |
7 | 9 | import time |
8 | 10 |
|
9 | 11 | import numpy as np |
@@ -233,21 +235,22 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): |
233 | 235 | print("> initializing torch distributed ...", flush=True) |
234 | 236 | # Manually set the device ids. |
235 | 237 | if device_count > 0: |
236 | | - device = args.rank % device_count |
237 | | - if args.local_rank is not None: |
238 | | - assert ( |
239 | | - args.local_rank == device |
240 | | - ), "expected local-rank to be the same as rank % device-count." |
241 | | - else: |
242 | | - args.local_rank = device |
243 | | - torch.cuda.set_device(device) |
| 238 | + torch.cuda.set_device(args.local_rank) |
| 239 | + device_id = torch.device(f'cuda:{args.local_rank}') |
| 240 | + else: |
| 241 | + device_id = None |
| 242 | + |
244 | 243 | # Call the init process |
245 | | - torch.distributed.init_process_group( |
246 | | - backend=args.distributed_backend, |
247 | | - world_size=args.world_size, |
248 | | - rank=args.rank, |
249 | | - timeout=timedelta(minutes=args.distributed_timeout_minutes), |
250 | | - ) |
| 244 | + init_process_group_kwargs = { |
| 245 | + 'backend' : args.distributed_backend, |
| 246 | + 'world_size': args.world_size, |
| 247 | + 'rank': args.rank, |
| 248 | + 'timeout': timedelta(minutes=args.distributed_timeout_minutes), |
| 249 | + } |
| 250 | + if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"): |
| 251 | + init_process_group_kwargs['device_id'] = device_id |
| 252 | + |
| 253 | + torch.distributed.init_process_group(**init_process_group_kwargs) |
251 | 254 |
|
252 | 255 | # Set the tensor model-parallel, pipeline model-parallel, and |
253 | 256 | # data-parallel communicators. |
|
0 commit comments