Commit ea14d26

SingleProcessPipeline and RPCPipeline with CPU RPC and CUDA RPC (#1177)
1 parent 2b9a280 commit ea14d26

File tree

4 files changed: +221 −146 lines changed


examples/BERT/cross_lingual_mlm_task.py

+101 −32
@@ -15,10 +15,12 @@
 from torch.utils.data import DataLoader
 
 from data import CC100
-from dist_model import DistCrossLingualMLMTask
 from model import CrossLingualMLMTask
+from pipeline import SingleProcessPipeline, RPCPipeline, RemoteBaseCPURPC, RemoteBaseCUDARPC
+from shard_model import XLMRModelShards, MLMShards
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
+from torchtext.experimental.models.utils import count_model_param
 
 
 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -43,27 +45,28 @@ def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
     return batch_data, targets
 
 
-def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, device, text_transform):
+def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devices, text_transform):
     total_loss = 0.
     dataloader = DataLoader(data_source, batch_size=1, # Set batch # to 1 for inference
                             shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))
     with torch.no_grad():
         for batch, (data, targets) in enumerate(dataloader):
-            data = data.to(device)
-            targets = targets.to(device)
+            data = data.to(devices[0])
+            targets = targets.to(devices[-1])
             output = model(data)
             total_loss += criterion(output.view(-1, ntokens), targets.view(-1)).item()
     return total_loss / (len(data_source) - 1) # Set batch # to 1 for inference
 
 
-def step(model, data, targets, criterion, optimizer, ntokens):
+def local_step(model, data, targets, criterion, optimizer, ntokens):
     optimizer.zero_grad()
     output = model(data)
     loss = criterion(output.view(-1, ntokens), targets.view(-1))
     loss.backward()
+    res = loss.item()
     torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
     optimizer.step()
-    return loss
+    return res
 
 
 def dist_step(model, data, targets, criterion, optimizer, ntokens):
@@ -73,11 +76,11 @@ def dist_step(model, data, targets, criterion, optimizer, ntokens):
         dist_autograd.backward(context_id, [loss])
         # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step(context_id)
-        return loss
+        return loss.item()
 
 
 def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-          optimizer, criterion, ntokens, epoch, last_lr, args, device, step_impl):
+          optimizer, criterion, ntokens, epoch, last_lr, args, devices, step_impl):
     model.train()
     total_loss = 0.
     start_time = time.time()
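The hunk above only shows the tail of dist_step. For orientation, here is a minimal sketch of the whole function, assuming it follows the standard torch.distributed.autograd pattern; the context manager itself sits outside this hunk, so its exact form here is an assumption, not the committed code.

import torch.distributed.autograd as dist_autograd


def dist_step(model, data, targets, criterion, optimizer, ntokens):
    # Run forward, backward and the optimizer step inside one distributed
    # autograd context so gradients can flow across RPC boundaries.
    with dist_autograd.context() as context_id:
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets.view(-1))
        dist_autograd.backward(context_id, [loss])
        # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step(context_id)  # DistributedOptimizer.step takes the context id
        return loss.item()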
@@ -86,9 +89,11 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
                             shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))
 
     for batch, (data, targets) in enumerate(dataloader):
-        loss = step_impl(model, data.to(device), targets.to(device), criterion, optimizer, ntokens)
+        data = data.to(devices[0])
+        targets = targets.to(devices[-1])
+        loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
 
-        total_loss += loss.item()
+        total_loss += loss
         if batch % args.log_interval == 0 and batch > 0:
             cur_loss = total_loss / args.log_interval
             elapsed = time.time() - start_time
@@ -116,47 +121,101 @@ def text_transform(x: str) -> List:
     pad_id = vocab(['pad'])[0]
     ntokens = len(vocab)
 
-    if not args.dist:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
-        model = model.to(device)
+    xlmr = XLMRModelShards(ntokens, args.emsize, args.nhead, args.nhid, args.dropout)
+    mlm = MLMShards(ntokens, args.emsize)
+    devices = [f"cuda:{i}" for i in range(args.gpus)] if torch.cuda.is_available() else ["cpu"]
+
+    if len(devices) == 1:
+        # In case of one device combine all layers into a single nn.Sequential
+        shards = [nn.Sequential(
+            xlmr.xlmr_embed(),
+            xlmr.encoder_layers(args.nlayers),
+            mlm.mlm()
+        )]
+    elif len(devices) == 2:
+        # In case of two devices split the model right in the middle and
+        # put the embeddings and half of encoders to the first shard and
+        # another half of encoders and mlm head to the second.
+        assert args.nlayers % 2 == 0
+        shards = [
+            nn.Sequential(
+                xlmr.xlmr_embed(),
+                xlmr.encoder_layers(args.nlayers // 2)
+            ),
+            nn.Sequential(
+                xlmr.encoder_layers(args.nlayers // 2),
+                mlm.mlm()
+            )
+        ]
+    else:
+        # In case of more than 2 devices put the embeddings and mlm head
+        # to the first and the last shard and split the encoders to equal
+        # parts among the rest of the shards
+        encoder_gpus = (args.gpus - 2)
+        assert args.nlayers % encoder_gpus == 0
+        encoders_per_gpu = args.nlayers // encoder_gpus
+        shards = [
+            xlmr.xlmr_embed(),
+            *[xlmr.encoder_layers(encoders_per_gpu) for _ in range(encoder_gpus)],
+            mlm.mlm()
+        ]
+
+    print('Shards parameters:')
+    total = 0
+    for i, shard in enumerate(shards):
+        params = count_model_param(shard)
+        total += params
+        print(f'shard{i} = {int(params)}M')
+    print(f'total = {int(total)}M')
+
+    print("Allocating memory")
+    if args.pipeline_mode == 'sp':
+        model = SingleProcessPipeline(shards, devices)
+
         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
     else:
-        device = "cpu"
-        model = DistCrossLingualMLMTask(args.split_size, ["worker1", "worker2"], ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        workers = [f"worker{i+1}" for i in range(len(devices))]
+        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
         optimizer = DistributedOptimizer(
             optim.Adam,
             model.parameter_rrefs(),
             lr=args.lr,
         )
         scheduler = None
 
+    print("Memory allocated")
+    # input("Memory allocated, check nvidia-smi for memory consumption")
+
     criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
     best_val_loss = None
     train_loss_log, val_loss_log = [], []
 
     for epoch in range(1, args.epochs + 1):
-        train_data = CC100('/datasets01/cc100/031720/', {'*.txt'}, start_line=args.start_line, num_lines=args.num_lines)
+        train_data = CC100(args.cc100_path, {'*.txt'}, start_line=args.start_line, num_lines=args.num_lines)
         from torchtext.datasets import WikiText2
-        val_data, = WikiText2(data_select='valid')
+        val_data = WikiText2(split='valid')
         val_data = [(17, item) for item in val_data if item != ' \n'] # english language type is 17 in CC100 dataset
 
         epoch_start_time = time.time()
         last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, last_lr, args, device, step if not args.dist else dist_step)
+              optimizer, criterion, ntokens, epoch, last_lr, args,
+              devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
+              local_step if args.pipeline_mode == 'sp' else dist_step)
 
         # Turn on evaluation mode which disables dropout.
         model.eval()
-        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, device, text_transform)
+        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
+                            devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
+                            text_transform)
         val_loss_log.append(val_loss)
         print('-' * 89)
         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
               'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                          val_loss, math.exp(val_loss)))
         print('-' * 89)
-        if not args.dist and not best_val_loss or val_loss < best_val_loss:
+        if args.pipeline_mode == 'sp' and not best_val_loss or val_loss < best_val_loss:
             with open(args.save, 'wb') as f:
                 torch.save(model, f)
                 best_val_loss = val_loss
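SingleProcessPipeline and RPCPipeline come from the new examples/BERT/pipeline.py added in this commit, which is not shown in this excerpt. As rough orientation only, a single-process pipeline over the shards list above could look like the sketch below; the class name and details are illustrative assumptions, not the committed implementation.

import torch.nn as nn


class NaiveSequentialPipeline(nn.Module):
    # Hypothetical stand-in for SingleProcessPipeline: place shard i on
    # devices[i] and move activations between devices inside forward().
    def __init__(self, shards, devices):
        super().__init__()
        assert len(shards) == len(devices)
        self.devices = devices
        self.shards = nn.ModuleList(
            shard.to(device) for shard, device in zip(shards, devices))

    def forward(self, x):
        for shard, device in zip(self.shards, self.devices):
            x = shard(x.to(device))
        return x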
@@ -173,7 +232,7 @@ def text_transform(x: str) -> List:
         def text_transform(x: str) -> List:
             return ref_model.encode(x).tolist()
         model = ref_model.model.encoder
-        model = model.to(device)
+        model = model.to(devices[0])
         # Turn on evaluation mode which disables dropout.
         model.eval()
         # from fairseq XLM-R model
@@ -187,7 +246,7 @@ def _forward(x):
                 return nn_model(x.transpose(0, 1))[0].transpose(0, 1)
             return _forward
         val_loss = evaluate(val_data, model_forward(model), mask_id, pad_id, ref_ntokens,
-                            criterion, args, device, text_transform)
+                            criterion, args, devices[0], text_transform)
         print('-' * 89)
         print('| reference model | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(val_loss, math.exp(val_loss)))
@@ -200,18 +259,26 @@ def run_worker(rank, args):
     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
 
     if rank == 0:
+        if args.pipeline_mode == 'cuda':
+            for i in range(args.gpus):
+                options.set_device_map("worker" + str(i + 1), {i:i})
         rpc.init_rpc(
             "master",
             rank=rank,
-            world_size=args.world_size,
+            world_size=args.gpus+1,
             rpc_backend_options=options
         )
         run_main(args)
     else:
+        if args.pipeline_mode == 'cuda':
+            if rank == 1:
+                options.set_device_map("master", {0:0})
+            else:
+                options.set_device_map("worker" + str(rank - 1), {(rank - 1):(rank - 2)})
         rpc.init_rpc(
             f"worker{rank}",
             rank=rank,
-            world_size=args.world_size,
+            world_size=args.gpus+1,
             rpc_backend_options=options
         )
         pass
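The CUDA-RPC branch above wires TensorPipe device maps so activations can travel GPU-to-GPU: the master maps its cuda:i to worker(i+1)'s cuda:i, worker1 maps its cuda:0 back to the master's cuda:0, and each later worker maps its own GPU to the previous worker's GPU. A small worked example, reproducing the mapping arithmetic above for an assumed args.gpus of 3:

gpus = 3  # assumed example value for args.gpus

# rank 0 ("master"): one map per worker, identity on the GPU index
master_maps = {f"worker{i + 1}": {i: i} for i in range(gpus)}

# rank 1 ("worker1") maps back to the master; rank k >= 2 maps to worker(k-1)
worker_maps = {"worker1": {"master": {0: 0}}}
for rank in range(2, gpus + 1):
    worker_maps[f"worker{rank}"] = {f"worker{rank - 1}": {rank - 1: rank - 2}}

print(master_maps)  # {'worker1': {0: 0}, 'worker2': {1: 1}, 'worker3': {2: 2}}
print(worker_maps)  # {'worker1': {'master': {0: 0}}, 'worker2': {'worker1': {1: 0}}, 'worker3': {'worker2': {2: 1}}}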
@@ -258,15 +325,17 @@ def run_worker(rank, args):
                         help='path to load the reference model for evaluation')
     parser.add_argument('--mask_frac', type=float, default=0.15,
                         help='the fraction of masked tokens')
-    parser.add_argument('--dist', action='store_true',
-                        help='run distributed version')
-    parser.add_argument('--world_size', type=int, default=3,
-                        help='world_size')
+    parser.add_argument('--cc100_path', type=str, default='/datasets01/cc100/031720/',
+                        help='path to cc100')
+    parser.add_argument('--gpus', type=int, default=1,
+                        help='number of GPUs to use')
+    parser.add_argument('--pipeline_mode', type=str, default='sp',
+                        help='pipeline mode, `cpu` for CPU RPC, `cuda` for CUDA RPC, `sp` for single process pipeline')
     parser.add_argument('--split_size', type=int, default=8,
                         help='split the input batch into micro-batches')
     args = parser.parse_args()
 
-    if args.dist:
-        mp.spawn(run_worker, args=(args,), nprocs=args.world_size, join=True)
-    else:
+    if args.pipeline_mode == 'sp':
         run_main(args)
+    else:
+        mp.spawn(run_worker, args=(args,), nprocs=args.gpus+1, join=True)
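With the new flags, 'sp' keeps everything in one process, while 'cpu' and 'cuda' spawn one RPC process per GPU plus the master. A quick sanity check of the process layout implied by mp.spawn and rpc.init_rpc above, using an assumed example value for --gpus:

gpus = 4  # assumed example value for --gpus
world_size = gpus + 1  # rank 0 is the master; ranks 1..gpus each host one shard
names = ["master"] + [f"worker{rank}" for rank in range(1, world_size)]
print(world_size, names)  # 5 ['master', 'worker1', 'worker2', 'worker3', 'worker4']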

examples/BERT/dist_model.py

-114
This file was deleted.
