Commit 0acf7e9

delock, tjruwase, and jeffra authored
[RFC] add device abstraction to allow other device than CUDA be used (#2221)
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
1 parent 80d8fcb commit 0acf7e9

67 files changed: +709 −389 lines changed

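Every file touched by this commit follows the same substitution: direct torch.cuda calls are routed through the new accelerator abstraction, so the benchmarks and inference containers no longer assume an NVIDIA device. A minimal sketch of the placement pattern (the helper name ones_on_device is illustrative, not part of the commit):

import torch
from deepspeed.accelerator import get_accelerator


def ones_on_device(shape, dtype, local_rank):
    # Previously: torch.ones(shape, dtype=dtype).cuda(local_rank)
    # Now the accelerator resolves the device string for this rank,
    # e.g. 'cuda:<local_rank>' on NVIDIA GPUs.
    device = get_accelerator().device_name(local_rank)
    return torch.ones(shape, dtype=dtype).to(device)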

benchmarks/communication/all_gather.py

+17 −10 lines changed

@@ -2,6 +2,7 @@
 
 from benchmarks.communication.utils import *
 from benchmarks.communication.constants import *
+from deepspeed.accelerator import get_accelerator
 
 import time
 
@@ -85,16 +86,20 @@ def run_all_gather(local_rank, args):
         try:
             mat = torch.ones(world_size,
                              M,
-                             dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                             dtype=getattr(
+                                 torch,
+                                 args.dtype)).to(
+                                     get_accelerator().device_name(local_rank))
             sync_all()
             input = ((mat.mul_(float(global_rank))).view(-1))
             # Delete original mat to avoid OOM
             del mat
-            torch.cuda.empty_cache()
+            get_accelerator().empty_cache()
             output = torch.zeros(input.nelement() * world_size,
-                                 dtype=getattr(torch,
-                                               args.dtype)).cuda(local_rank)
+                                 dtype=getattr(
+                                     torch,
+                                     args.dtype)).to(
+                                         get_accelerator().device_name(local_rank))
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 if dist.get_rank() == 0:
@@ -123,15 +128,17 @@ def run_all_gather(local_rank, args):
         try:
             mat = torch.ones(elements_per_gpu,
                              dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                                           args.dtype)).to(
+                                               get_accelerator().device_name(local_rank))
             # multiply each GPU's tensor by the rank to ease debugging
             input = ((mat.mul_(float(global_rank))).view(-1))
             # Delete original mat to avoid OOM
             del mat
-            torch.cuda.empty_cache()
-            output = torch.zeros(elements_per_gpu * world_size,
-                                 dtype=getattr(torch,
-                                               args.dtype)).cuda(local_rank)
+            get_accelerator().empty_cache()
+            output = torch.zeros(
+                elements_per_gpu * world_size,
+                dtype=getattr(torch,
+                              args.dtype)).to(get_accelerator().device_name(local_rank))
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 if dist.get_rank() == 0:
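The paths above delete the input matrix and clear the device cache before allocating the all-gather output, and both steps now go through the accelerator. A hedged sketch of that OOM-avoidance sequence (the function name and argument list are illustrative; the body mirrors the benchmark):

import torch
from deepspeed.accelerator import get_accelerator


def build_gather_buffers(elements_per_gpu, world_size, dtype, local_rank, global_rank):
    device = get_accelerator().device_name(local_rank)
    mat = torch.ones(elements_per_gpu, dtype=dtype).to(device)
    input = (mat.mul_(float(global_rank))).view(-1)
    # Drop the reference and release cached blocks before the much larger
    # output allocation, as the benchmark does to avoid OOM.
    del mat
    get_accelerator().empty_cache()
    output = torch.zeros(elements_per_gpu * world_size, dtype=dtype).to(device)
    return input, output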

benchmarks/communication/all_reduce.py

+7 −3 lines changed

@@ -2,6 +2,7 @@
 
 from benchmarks.communication.utils import *
 from benchmarks.communication.constants import *
+from deepspeed.accelerator import get_accelerator
 
 import time
 
@@ -64,8 +65,10 @@ def run_all_reduce(local_rank, args):
         try:
             mat = torch.ones(world_size,
                              M,
-                             dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                             dtype=getattr(
+                                 torch,
+                                 args.dtype)).to(
+                                     get_accelerator().device_name(local_rank))
             sync_all()
             input = ((mat.mul_(float(global_rank))).view(-1))
         except RuntimeError as e:
@@ -88,7 +91,8 @@ def run_all_reduce(local_rank, args):
         try:
             mat = torch.ones(elements_per_gpu,
                              dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                                           args.dtype)).to(
+                                               get_accelerator().device_name(local_rank))
             input = ((mat.mul_(float(global_rank))).view(-1))
         except RuntimeError as e:
             if 'out of memory' in str(e):

benchmarks/communication/all_to_all.py

+12 −7 lines changed

@@ -2,6 +2,7 @@
 
 from benchmarks.communication.utils import *
 from benchmarks.communication.constants import *
+from deepspeed.accelerator import get_accelerator
 
 import time
 
@@ -63,8 +64,10 @@ def run_all_to_all(local_rank, args):
         try:
             mat = torch.ones(world_size,
                              M,
-                             dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                             dtype=getattr(
+                                 torch,
+                                 args.dtype)).to(
+                                     get_accelerator().device_name(local_rank))
             assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"
             sync_all()
             input = ((mat.mul_(float(global_rank))).view(-1))
@@ -88,15 +91,17 @@ def run_all_to_all(local_rank, args):
         try:
             mat = torch.ones(elements_per_gpu,
                              dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                                           args.dtype)).to(
+                                               get_accelerator().device_name(local_rank))
             assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
             input = ((mat.mul_(float(global_rank))).view(-1))
             # Delete original mat to avoid OOM
             del mat
-            torch.cuda.empty_cache()
-            output = torch.zeros(elements_per_gpu,
-                                 dtype=getattr(torch,
-                                               args.dtype)).cuda(local_rank)
+            get_accelerator().empty_cache()
+            output = torch.zeros(
+                elements_per_gpu,
+                dtype=getattr(torch,
+                              args.dtype)).to(get_accelerator().device_name(local_rank))
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 if dist.get_rank() == 0:

benchmarks/communication/broadcast.py

+7 −3 lines changed

@@ -3,6 +3,7 @@
 import torch
 from benchmarks.communication.utils import *
 from benchmarks.communication.constants import *
+from deepspeed.accelerator import get_accelerator
 
 import time
 
@@ -65,8 +66,10 @@ def run_broadcast(local_rank, args):
         try:
             mat = torch.ones(world_size,
                              M,
-                             dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                             dtype=getattr(
+                                 torch,
+                                 args.dtype)).to(
+                                     get_accelerator().device_name(local_rank))
             sync_all()
             input = ((mat.mul_(float(global_rank))).view(-1))
         except RuntimeError as e:
@@ -89,7 +92,8 @@ def run_broadcast(local_rank, args):
         try:
             mat = torch.ones(elements_per_gpu,
                              dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                                           args.dtype)).to(
+                                               get_accelerator().device_name(local_rank))
             input = ((mat.mul_(float(global_rank))).view(-1))
         except RuntimeError as e:
             if 'out of memory' in str(e):

benchmarks/communication/constants.py

+2 −1 lines changed

@@ -1,9 +1,10 @@
 '''Copyright The Microsoft DeepSpeed Team'''
+from deepspeed.accelerator import get_accelerator
 
 DEFAULT_WARMUPS = 5
 DEFAULT_TRIALS = 50
 DEFAULT_TYPE = 'float'
-DEFAULT_BACKEND = 'nccl'
+DEFAULT_BACKEND = get_accelerator().communication_backend_name()
 DEFAULT_UNIT = 'Gbps'
 DEFAULT_DIST = 'deepspeed'
 DEFAULT_MAXSIZE = 24
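With this change the default backend is no longer hard-coded to 'nccl'; it is whatever collective library the active accelerator reports. A short sketch of reading that default (the print is illustrative):

from deepspeed.accelerator import get_accelerator

# 'nccl' on the CUDA accelerator; other accelerators report their own
# library, which is why the benchmark parser now also accepts 'ccl'.
backend = get_accelerator().communication_backend_name()
print(f"Communication backend: {backend}")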

benchmarks/communication/pt2pt.py

+7 −3 lines changed

@@ -2,6 +2,7 @@
 
 from benchmarks.communication.utils import *
 from benchmarks.communication.constants import *
+from deepspeed.accelerator import get_accelerator
 
 import time
 
@@ -83,8 +84,10 @@ def run_pt2pt(local_rank, args):
         try:
             mat = torch.ones(world_size,
                              M,
-                             dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                             dtype=getattr(
+                                 torch,
+                                 args.dtype)).to(
+                                     get_accelerator().device_name(local_rank))
             sync_all()
             input = ((mat.mul_(float(global_rank))).view(-1))
         except RuntimeError as e:
@@ -107,7 +110,8 @@ def run_pt2pt(local_rank, args):
         try:
             mat = torch.ones(elements_per_gpu,
                              dtype=getattr(torch,
-                                           args.dtype)).cuda(local_rank)
+                                           args.dtype)).to(
+                                               get_accelerator().device_name(local_rank))
             input = ((mat.mul_(float(global_rank))).view(-1))
         except RuntimeError as e:
             if 'out of memory' in str(e):

benchmarks/communication/utils.py

+7 −6 lines changed

@@ -5,6 +5,7 @@
 import math
 import argparse
 from benchmarks.communication.constants import *
+from deepspeed.accelerator import get_accelerator
 
 global dist
 
@@ -14,7 +15,7 @@ def init_torch_distributed(backend):
     import torch.distributed as dist
     torch.distributed.init_process_group(backend)
     local_rank = int(os.environ['LOCAL_RANK'])
-    torch.cuda.set_device(local_rank)
+    get_accelerator().set_device(local_rank)
 
 
 def init_deepspeed_comm(backend):
@@ -23,7 +24,7 @@ def init_deepspeed_comm(backend):
     import deepspeed.comm as dist
     deepspeed.init_distributed(dist_backend=backend)
     local_rank = int(os.environ['LOCAL_RANK'])
-    torch.cuda.set_device(local_rank)
+    get_accelerator().set_device(local_rank)
 
 
 def init_processes(local_rank, args):
@@ -101,14 +102,13 @@ def get_metric_strings(args, tput, busbw, duration):
 
 
 def sync_all():
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     dist.barrier()
 
 
 def max_numel(comm_op, dtype, mem_factor, local_rank, args):
     dtype_size = _element_size(dtype)
-    max_memory_per_gpu = torch.cuda.get_device_properties(
-        local_rank).total_memory * mem_factor
+    max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
     if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
         elements_per_gpu = int(max_memory_per_gpu // dtype_size)
     elif comm_op == 'all_gather':
@@ -185,7 +185,8 @@ def benchmark_parser():
     parser.add_argument("--backend",
                         type=str,
                         default=DEFAULT_BACKEND,
-                        choices=['nccl'],
+                        choices=['nccl',
+                                 'ccl'],
                         help='Communication library to use')
     parser.add_argument("--dist",
                         type=str,
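max_numel now sizes its test tensors from the memory the accelerator reports rather than torch.cuda.get_device_properties. A sketch of that sizing logic (the helper name and the element-size computation here are illustrative; the original calls an _element_size helper):

import torch
from deepspeed.accelerator import get_accelerator


def max_elements(local_rank, dtype=torch.float32, mem_factor=0.8):
    # Total device memory for this rank, scaled down to leave headroom.
    max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    return int(max_memory_per_gpu // dtype_size)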

benchmarks/inference/bert-bench.py

+4 −3 lines changed

@@ -5,6 +5,7 @@
 import deepspeed
 import argparse
 from transformers import pipeline
+from deepspeed.accelerator import get_accelerator
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", "-m", type=str, help="hf model name")
@@ -46,7 +47,7 @@ def print_latency(latency_set, title, warmup=3):
     print("\t999 Latency: {0:8.2f} ms".format(p999 * 1000))
 
 
-deepspeed.init_distributed("nccl")
+deepspeed.init_distributed()
 
 print(args.model, args.max_tokens, args.dtype)
 
@@ -75,10 +76,10 @@ def print_latency(latency_set, title, warmup=3):
 times = []
 mtimes = []
 for i in range(args.trials):
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     start = time.time()
     r = pipe(f"Hello I'm a {mask} model")
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     end = time.time()
     responses.append(r)
     times.append((end - start))
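Latency measurement brackets each pipeline call with a device synchronize so the host timer only counts completed device work; the commit swaps the CUDA-specific call for the accelerator's. A minimal timing sketch (run_once stands in for the pipeline call):

import time
from deepspeed.accelerator import get_accelerator


def time_trials(run_once, trials=10):
    latencies = []
    for _ in range(trials):
        get_accelerator().synchronize()  # drain any pending device work
        start = time.time()
        run_once()
        get_accelerator().synchronize()  # wait for this trial to finish
        latencies.append(time.time() - start)
    return latencies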

benchmarks/inference/gpt-bench.py

+4 −3 lines changed

@@ -6,6 +6,7 @@
 import deepspeed
 import argparse
 from transformers import pipeline
+from deepspeed.accelerator import get_accelerator
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", "-m", type=str, help="hf model name")
@@ -63,7 +64,7 @@ def print_latency(latency_set, title, warmup=3):
     print("\t999 Latency: {0:8.2f} ms".format(p999 * 1000))
 
 
-deepspeed.init_distributed("nccl")
+deepspeed.init_distributed()
 
 if args.local_rank == 0:
     print("BENCHMARK SETTINGS:")
@@ -102,10 +103,10 @@ def print_latency(latency_set, title, warmup=3):
 times = []
 mtimes = []
 for i in range(args.trials):
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     start = time.time()
     r = pipe("DeepSpeed is", do_sample=False, max_new_tokens=args.max_tokens)
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()
     end = time.time()
     responses.append(r)
     times.append(end - start)  # / (args.max_tokens - 3))
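Both inference benchmarks also drop the explicit "nccl" argument: deepspeed.init_distributed() now falls back to the communication backend reported by the active accelerator, which is what lets the same scripts run on non-CUDA devices. The before/after, with the old form kept as a comment:

import deepspeed

# Before this commit the backend was pinned to NVIDIA's collective library:
# deepspeed.init_distributed("nccl")

# After: let DeepSpeed pick the accelerator's default backend.
deepspeed.init_distributed()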

deepspeed/module_inject/containers/base.py

+8 −5 lines changed

@@ -5,6 +5,7 @@
 import torch
 
 from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
+from deepspeed.accelerator import get_accelerator
 
 
 class BaseConvolutionContainer(ABC):
@@ -216,12 +217,14 @@ def copy_data_to_new_module(self):
             self.module.mlp.attn_nb = self.attn_nb
         else:
             self.module.mlp.attn_nw.data.copy_(
-                self.attn_nw.to(torch.cuda.current_device()))
+                self.attn_nw.to(get_accelerator().current_device_name()))
             self.module.mlp.attn_nb.data.copy_(
-                self.attn_nb.to(torch.cuda.current_device()))
+                self.attn_nb.to(get_accelerator().current_device_name()))
 
-        self.module.norm_w.data.copy_(self.input_nw.to(torch.cuda.current_device()))
-        self.module.norm_b.data.copy_(self.input_nb.to(torch.cuda.current_device()))
+        self.module.norm_w.data.copy_(
+            self.input_nw.to(get_accelerator().current_device_name()))
+        self.module.norm_b.data.copy_(
+            self.input_nb.to(get_accelerator().current_device_name()))
 
     def transpose(self):
         self.transpose_attention()
@@ -241,5 +244,5 @@ def transpose_impl(self, data):
         data = data.contiguous()
         data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
         data = data.reshape(data.shape[-1], data.shape[-2])
-        data.to(torch.cuda.current_device())
+        data.to(get_accelerator().current_device_name())
         return data
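Parameter copies in the inference containers resolve their target device through the accelerator instead of torch.cuda.current_device(). A hedged sketch of the move (the helper and tensor names are illustrative):

import torch
from deepspeed.accelerator import get_accelerator


def copy_to_current_device(dst: torch.Tensor, src: torch.Tensor) -> None:
    # current_device_name() returns a device string such as 'cuda:0',
    # so the same copy path works on non-CUDA accelerators.
    dst.data.copy_(src.to(get_accelerator().current_device_name()))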
