5
5
import math
6
6
import argparse
7
7
from benchmarks .communication .constants import *
8
+ from deepspeed .accelerator import get_accelerator
8
9
9
10
global dist
10
11
@@ -14,7 +15,7 @@ def init_torch_distributed(backend):
14
15
import torch .distributed as dist
15
16
torch .distributed .init_process_group (backend )
16
17
local_rank = int (os .environ ['LOCAL_RANK' ])
17
- torch . cuda .set_device (local_rank )
18
+ get_accelerator () .set_device (local_rank )
18
19
19
20
20
21
def init_deepspeed_comm (backend ):
@@ -23,7 +24,7 @@ def init_deepspeed_comm(backend):
23
24
import deepspeed .comm as dist
24
25
deepspeed .init_distributed (dist_backend = backend )
25
26
local_rank = int (os .environ ['LOCAL_RANK' ])
26
- torch . cuda .set_device (local_rank )
27
+ get_accelerator () .set_device (local_rank )
27
28
28
29
29
30
def init_processes (local_rank , args ):
@@ -101,14 +102,13 @@ def get_metric_strings(args, tput, busbw, duration):
101
102
102
103
103
104
def sync_all ():
104
- torch . cuda .synchronize ()
105
+ get_accelerator () .synchronize ()
105
106
dist .barrier ()
106
107
107
108
108
109
def max_numel (comm_op , dtype , mem_factor , local_rank , args ):
109
110
dtype_size = _element_size (dtype )
110
- max_memory_per_gpu = torch .cuda .get_device_properties (
111
- local_rank ).total_memory * mem_factor
111
+ max_memory_per_gpu = get_accelerator ().total_memory (local_rank ) * mem_factor
112
112
if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast' :
113
113
elements_per_gpu = int (max_memory_per_gpu // dtype_size )
114
114
elif comm_op == 'all_gather' :
@@ -185,7 +185,8 @@ def benchmark_parser():
185
185
parser .add_argument ("--backend" ,
186
186
type = str ,
187
187
default = DEFAULT_BACKEND ,
188
- choices = ['nccl' ],
188
+ choices = ['nccl' ,
189
+ 'ccl' ],
189
190
help = 'Communication library to use' )
190
191
parser .add_argument ("--dist" ,
191
192
type = str ,
0 commit comments