Commit 213f0ef

Merge pull request #1133 from pytorch/xlmr_mlm_two_shards
DistCrossLingualMLMTask with two shards
2 parents: 7b6eef4 + 456b521

File tree: 3 files changed, +219 / -31 lines

examples/BERT/README.md (+4)
@@ -151,6 +151,10 @@ To run the workflow with 3000 lines from each of the 100 languages (CC-100 datas

     python cross_lingual_mlm_task.py --num_lines 3000

+To run distributed training, use the '--dist' flag; to set the world size, pass '--world_size=N'. The default world size is 3: one master and two worker nodes.
+
+    python cross_lingual_mlm_task.py --num_lines 3000 --dist
+
 To Run the reference XLM-R model from fairseq, download and unzip the pretrained model from [link](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz).

     python cross_lingual_mlm_task.py --eval_ref ./xlmr.large
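For reference (this exact invocation is not shown in the commit), the world size can be set explicitly by combining the flags above:

    python cross_lingual_mlm_task.py --num_lines 3000 --dist --world_size=3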

examples/BERT/cross_lingual_mlm_task.py (+101 / -31)
@@ -1,15 +1,24 @@
 import argparse
-import time
 import math
+import os
+import time
+from typing import List
+
 import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
 import torch.nn as nn
+import torch.optim as optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+
 from data import CC100
+from dist_model import DistCrossLingualMLMTask
 from model import CrossLingualMLMTask
-from torch.utils.data import DataLoader
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
-from torch.nn.utils.rnn import pad_sequence
-from typing import List


 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -47,8 +56,28 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
     return total_loss / (len(data_source) - 1)  # Set batch # to 1 for inference


+def step(model, data, targets, criterion, optimizer, ntokens):
+    optimizer.zero_grad()
+    output = model(data)
+    loss = criterion(output.view(-1, ntokens), targets.view(-1))
+    loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+    optimizer.step()
+    return loss
+
+
+def dist_step(model, data, targets, criterion, optimizer, ntokens):
+    with dist_autograd.context() as context_id:
+        output = model(data)
+        loss = criterion(output.view(-1, ntokens), targets.view(-1))
+        dist_autograd.backward(context_id, [loss])
+        # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+        optimizer.step(context_id)
+    return loss
+
+
 def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-          optimizer, criterion, ntokens, epoch, scheduler, args, device, rank=None):
+          optimizer, criterion, ntokens, epoch, last_lr, args, device, step_impl):
     model.train()
     total_loss = 0.
     start_time = time.time()
@@ -57,33 +86,25 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
                             shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))

     for batch, (data, targets) in enumerate(dataloader):
-        optimizer.zero_grad()
-        data = data.to(device)
-        targets = targets.to(device)
-        output = model(data)
-        loss = criterion(output.view(-1, ntokens), targets.view(-1))
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        optimizer.step()
+        loss = step_impl(model, data.to(device), targets.to(device), criterion, optimizer, ntokens)
+
         total_loss += loss.item()
         if batch % args.log_interval == 0 and batch > 0:
             cur_loss = total_loss / args.log_interval
             elapsed = time.time() - start_time
-            if (rank is None) or rank == 0:
-                train_loss_log[-1] = cur_loss
-                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
-                      'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
-                                                          len(train_data) // args.batch_size,
-                                                          scheduler.get_last_lr()[0],
-                                                          elapsed * 1000 / args.log_interval,
-                                                          cur_loss, math.exp(cur_loss)))
+            train_loss_log[-1] = cur_loss
+            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
+                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
+                                                      len(train_data) // args.batch_size,
+                                                      last_lr,
+                                                      elapsed * 1000 / args.log_interval,
+                                                      cur_loss, math.exp(cur_loss)))
             total_loss = 0
             start_time = time.time()


-def run_main(args, rank=None):
+def run_main(args):
     torch.manual_seed(args.seed)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # Set up tokenizer and vocab
     tokenizer = sentencepiece_tokenizer(args.spm_path)
@@ -95,11 +116,23 @@ def text_transform(x: str) -> List:
     pad_id = vocab(['pad'])[0]
     ntokens = len(vocab)

-    model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
-    model = model.to(device)
+    if not args.dist:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        model = model.to(device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
+    else:
+        device = "cpu"
+        model = DistCrossLingualMLMTask(args.split_size, ["worker1", "worker2"], ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        optimizer = DistributedOptimizer(
+            optim.Adam,
+            model.parameter_rrefs(),
+            lr=args.lr,
+        )
+        scheduler = None
+
     criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
     best_val_loss = None
     train_loss_log, val_loss_log = [], []

@@ -110,8 +143,9 @@ def text_transform(x: str) -> List:
         val_data = [(17, item) for item in val_data if item != ' \n']  # english language type is 17 in CC100 dataset

         epoch_start_time = time.time()
+        last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, scheduler, args, device, rank)
+              optimizer, criterion, ntokens, epoch, last_lr, args, device, step if not args.dist else dist_step)

         # Turn on evaluation mode which disables dropout.
         model.eval()
@@ -122,12 +156,13 @@ def text_transform(x: str) -> List:
               'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                          val_loss, math.exp(val_loss)))
         print('-' * 89)
-        if not best_val_loss or val_loss < best_val_loss:
+        if not args.dist and (not best_val_loss or val_loss < best_val_loss):
             with open(args.save, 'wb') as f:
                 torch.save(model, f)
             best_val_loss = val_loss
         else:
-            scheduler.step()
+            if scheduler is not None:
+                scheduler.step()

     # Run reference XLM-R model from fairseq
     if args.eval_ref != 'None':
@@ -159,6 +194,32 @@ def _forward(x):
     print('-' * 89)


+def run_worker(rank, args):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '29500'
+    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
+
+    if rank == 0:
+        rpc.init_rpc(
+            "master",
+            rank=rank,
+            world_size=args.world_size,
+            rpc_backend_options=options
+        )
+        run_main(args)
+    else:
+        rpc.init_rpc(
+            f"worker{rank}",
+            rank=rank,
+            world_size=args.world_size,
+            rpc_backend_options=options
+        )
+        pass
+
+    # block until all rpcs finish
+    rpc.shutdown()
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='PyTorch Cross-lingual XLM MLM')
     parser.add_argument('--emsize', type=int, default=768,
@@ -197,6 +258,15 @@ def _forward(x):
                         help='path to load the reference model for evaluation')
     parser.add_argument('--mask_frac', type=float, default=0.15,
                         help='the fraction of masked tokens')
+    parser.add_argument('--dist', action='store_true',
+                        help='run distributed version')
+    parser.add_argument('--world_size', type=int, default=3,
+                        help='world_size')
+    parser.add_argument('--split_size', type=int, default=8,
+                        help='split the input batch into micro-batches')
     args = parser.parse_args()

-    run_main(args)
+    if args.dist:
+        mp.spawn(run_worker, args=(args,), nprocs=args.world_size, join=True)
+    else:
+        run_main(args)
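On the training-script side, the key substitution is dist_step: loss.backward() and optimizer.step() become dist_autograd.backward(context_id, [loss]) and DistributedOptimizer.step(context_id). The sketch below shows that pattern in isolation; it assumes a single-process RPC group (world_size=1) and a toy nn.Linear model instead of this example's master/worker setup, and toy_dist_step is an illustrative name, not part of the commit.

# Minimal sketch of the dist_autograd + DistributedOptimizer pattern used above.
# Assumptions (not part of the commit): a single-process RPC group and a toy
# linear model stand in for the sharded transformer and its workers.
import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef


def toy_dist_step(model, data, targets, criterion, optimizer):
    # Gradients are recorded in the distributed autograd context rather than
    # in p.grad, so the optimizer step is keyed by the same context_id.
    with dist_autograd.context() as context_id:
        loss = criterion(model(data), targets)
        dist_autograd.backward(context_id, [loss])
        optimizer.step(context_id)
    return loss


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    rpc.init_rpc("worker0", rank=0, world_size=1)

    model = nn.Linear(8, 2)
    criterion = nn.MSELoss()
    # DistributedOptimizer takes RRefs to parameters; here they are all local.
    optimizer = DistributedOptimizer(optim.Adam, [RRef(p) for p in model.parameters()], lr=0.01)

    loss = toy_dist_step(model, torch.randn(4, 8), torch.randn(4, 2), criterion, optimizer)
    print(f"loss: {loss.item():.4f}")

    rpc.shutdown()  # block until all outstanding RPC work finishes

Because distributed autograd accumulates gradients inside the context rather than in each parameter's .grad field, the clip_grad_norm_ call over model.parameters() is left commented out in dist_step above.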

examples/BERT/dist_model.py (new file, +114)
@@ -0,0 +1,114 @@
+import threading
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed.rpc import RRef
+from torch.nn import Linear, LayerNorm
+
+from model import XLMREmbedding, TransformerEncoderLayer, TransformerEncoder
+
+
+def get_cuda_if_available(i):
+    assert i >= 0
+    if torch.cuda.is_available():
+        return f"cuda:{min(i, torch.cuda.device_count() - 1)}"
+    else:
+        return "cpu"
+
+
+class CrossLingualMLMTaskBase(nn.Module):
+    def __init__(self, device):
+        super(CrossLingualMLMTaskBase, self).__init__()
+        self.device = device
+        self._lock = threading.Lock()
+
+    def forward(self, x_rref):
+        x = x_rref.to_here().to(self.device)
+        with self._lock:
+            out = self._forward(x)
+        return out.cpu()
+
+    def parameter_rrefs(self):
+        r"""
+        Create one RRef for each parameter in the given local module, and return a
+        list of RRefs.
+        """
+        return [RRef(p) for p in self.parameters()]
+
+
+class CrossLingualMLMTaskShard1(CrossLingualMLMTaskBase):
+    def __init__(self, device, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+        super(CrossLingualMLMTaskShard1, self).__init__(device)
+        self.xlmr_embed = XLMREmbedding(ntoken, ninp, dropout).to(device)
+        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers // 2).to(device)
+
+    def _forward(self, src):
+        output = self.xlmr_embed(src)
+        output = self.transformer_encoder(output)
+        return output
+
+
+class CrossLingualMLMTaskShard2(CrossLingualMLMTaskBase):
+    def __init__(self, device, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+        super(CrossLingualMLMTaskShard2, self).__init__(device)
+        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers // 2).to(device)
+        self.mlm_span = Linear(ninp, ninp).to(device)
+        self.activation = F.gelu
+        self.norm_layer = LayerNorm(ninp, eps=1e-12).to(device)
+        self.mlm_head = Linear(ninp, ntoken).to(device)
+
+    def _forward(self, src):
+        output = self.transformer_encoder(src)
+        output = self.mlm_span(output)
+        output = self.activation(output)
+        output = self.norm_layer(output)
+        output = self.mlm_head(output)
+        return output
+
+
+class DistCrossLingualMLMTask(nn.Module):
+    """CrossLingualMLMTask split into two shards"""
+
+    def __init__(self, split_size, workers, *args, **kwargs):
+        super(DistCrossLingualMLMTask, self).__init__()
+
+        self.split_size = split_size
+
+        # Put the first shard (embedding + first half of the encoder) on workers[0]
+        self.p1_rref = rpc.remote(
+            workers[0],
+            CrossLingualMLMTaskShard1,
+            args=(get_cuda_if_available(0),) + args,
+            kwargs=kwargs
+        )
+
+        # Put the second shard (second half of the encoder + MLM head) on workers[1]
+        self.p2_rref = rpc.remote(
+            workers[1],
+            CrossLingualMLMTaskShard2,
+            args=(get_cuda_if_available(1),) + args,
+            kwargs=kwargs
+        )
+
+    def forward(self, xs):
+        # Split the input batch xs into micro-batches, and collect async RPC
+        # futures into a list
+        out_futures = []
+        for x in iter(xs.split(self.split_size, dim=0)):
+            x_rref = RRef(x)
+            y_rref = self.p1_rref.remote().forward(x_rref)
+            z_fut = self.p2_rref.rpc_async().forward(y_rref)
+            out_futures.append(z_fut)
+
+        # collect and cat all output tensors into one tensor.
+        return torch.cat(torch.futures.wait_all(out_futures))
+
+    def parameter_rrefs(self):
+        remote_params = []
+        remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
+        remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
+        return remote_params
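DistCrossLingualMLMTask.forward pipelines work across the two remote shards: each chunk split from the input along dim 0 goes to shard 1 via remote().forward (which returns an RRef without blocking), that RRef is handed to shard 2 via rpc_async().forward, and the resulting futures are awaited and concatenated. The sketch below is a hedged illustration of how the master process might drive this model once run_worker has set up the RPC group with "worker1" and "worker2"; the vocabulary size, model hyperparameters, learning rate, and random token ids are placeholders, not values taken from the commit, and one_training_step is an illustrative name.

# Illustrative driver for the two-shard model, as it could run on the master
# (rank 0) inside an RPC group containing "worker1" and "worker2".
import torch
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer

from dist_model import DistCrossLingualMLMTask


def one_training_step(ntokens=1000, seq_len=128, batch_size=32):
    # Shard 1 (embedding + first half of the encoder) is constructed on
    # "worker1", shard 2 (second half + MLM head) on "worker2". forward()
    # splits its input along dim 0 into chunks of split_size and pipelines
    # them through the two shards.
    model = DistCrossLingualMLMTask(8, ["worker1", "worker2"],
                                    ntokens, 768, 12, 3072, 12, 0.2)
    optimizer = DistributedOptimizer(optim.Adam, model.parameter_rrefs(), lr=6e-4)
    criterion = nn.CrossEntropyLoss()

    data = torch.randint(ntokens, (seq_len, batch_size))     # token ids, sequence-first as in collate_batch
    targets = torch.randint(ntokens, (seq_len, batch_size))  # MLM targets, same shape

    with dist_autograd.context() as context_id:
        output = model(data)                                  # (seq_len, batch_size, ntokens)
        loss = criterion(output.view(-1, ntokens), targets.view(-1))
        dist_autograd.backward(context_id, [loss])
        optimizer.step(context_id)
    return loss.item()

Note that parameter_rrefs() gathers RRefs from both remote shards, so a single DistributedOptimizer on the master updates parameters that live on two different workers.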
