Commit 77b3259

args.pipeline_mode=pipe to use torch.distributed.pipeline.sync.Pipe
1 parent ea14d26 commit 77b3259
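
The new 'pipe' mode builds on torch.distributed.pipeline.sync.Pipe: the sharded model is wrapped in Pipe, each batch is split into micro-batches via the chunks argument, forward returns an RRef (hence the .local_value() call in pipe_step below), and the RPC framework must be initialized even though everything runs in one process. A minimal, self-contained sketch of that pattern, not taken from this repository; the shard sizes, device strings and chunks value are illustrative assumptions, and it targets the PyTorch releases that ship this Pipe module:

import os
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe needs the RPC framework, even for a single-process run.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
rpc.init_rpc("worker", rank=0, world_size=1)

# Two toy shards, each already placed on its device ("cpu" assumed here;
# the example in this commit puts one shard per GPU).
shard0 = nn.Linear(16, 16).to("cpu")
shard1 = nn.Linear(16, 16).to("cpu")
model = Pipe(nn.Sequential(shard0, shard1), chunks=4)  # 4 micro-batches per batch

x = torch.randn(8, 16)        # input lives on the first shard's device
out = model(x).local_value()  # forward returns an RRef, hence .local_value()
out.sum().backward()          # gradients flow back through every shard

rpc.shutdown()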

File tree: 2 files changed (+53 -23 lines)


examples/BERT/cross_lingual_mlm_task.py  (+48 -16)

@@ -11,6 +11,7 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.distributed.optim import DistributedOptimizer
+from torch.distributed.rpc import RRef
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 
@@ -21,6 +22,7 @@
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
 from torchtext.experimental.models.utils import count_model_param
+from torch.distributed.pipeline.sync import Pipe
 
 
 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -58,7 +60,7 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
     return total_loss / (len(data_source) - 1)  # Set batch # to 1 for inference
 
 
-def local_step(model, data, targets, criterion, optimizer, ntokens):
+def local_step(model, data, targets, criterion, optimizer, ntokens, args):
     optimizer.zero_grad()
     output = model(data)
     loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -69,7 +71,18 @@ def local_step(model, data, targets, criterion, optimizer, ntokens):
     return res
 
 
-def dist_step(model, data, targets, criterion, optimizer, ntokens):
+def pipe_step(model, data, targets, criterion, optimizer, ntokens, args):
+    optimizer.zero_grad()
+    output = model(data).local_value()  # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef
+    loss = criterion(output.view(-1, ntokens), targets.view(-1))
+    loss.backward()
+    res = loss.item()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+    optimizer.step()
+    return res
+
+
+def rpc_step(model, data, targets, criterion, optimizer, ntokens, args):
     with dist_autograd.context() as context_id:
         output = model(data)
         loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -91,7 +104,7 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
     for batch, (data, targets) in enumerate(dataloader):
         data = data.to(devices[0])
         targets = targets.to(devices[-1])
-        loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
+        loss = step_impl(model, data, targets, criterion, optimizer, ntokens, args)
 
         total_loss += loss
         if batch % args.log_interval == 0 and batch > 0:
@@ -171,12 +184,19 @@ def text_transform(x: str) -> List:
     print("Allocating memory")
     if args.pipeline_mode == 'sp':
         model = SingleProcessPipeline(shards, devices)
-
         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
-    else:
+    elif args.pipeline_mode == 'pipe':
+        model = Pipe(SingleProcessPipeline(shards, devices, to_device=False), chunks=args.batch_size // args.split_size)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
+    elif args.pipeline_mode == 'cpu' or args.pipeline_mode == 'cuda':
         workers = [f"worker{i+1}" for i in range(len(devices))]
-        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
+        if args.pipeline_mode == 'cpu':
+            impl = RemoteBaseCPURPC
+        elif args.pipeline_mode == 'cuda':
+            impl = RemoteBaseCUDARPC
+        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=impl)
         optimizer = DistributedOptimizer(
             optim.Adam,
             model.parameter_rrefs(),
@@ -199,16 +219,25 @@ def text_transform(x: str) -> List:
 
         epoch_start_time = time.time()
         last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
+
+        if args.pipeline_mode == 'sp':
+            step = local_step
+        elif args.pipeline_mode == 'pipe':
+            step = pipe_step
+        else:
+            step = rpc_step
+
+        if args.pipeline_mode == 'cpu':
+            train_devices = ["cpu"]  # Because "TensorPipe RPC backend only supports CPU tensors by default, please move your tensors to CPU before sending them over RPC"
+        else:
+            train_devices = devices
+
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, last_lr, args,
-              devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
-              local_step if args.pipeline_mode == 'sp' else dist_step)
+              optimizer, criterion, ntokens, epoch, last_lr, args, train_devices, step)
 
         # Turn on evaluation mode which disables dropout.
         model.eval()
-        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
-                            devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
-                            text_transform)
+        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, train_devices, text_transform)
         val_loss_log.append(val_loss)
         print('-' * 89)
         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
@@ -253,7 +282,7 @@ def _forward(x):
     print('-' * 89)
 
 
-def run_worker(rank, args):
+def run_worker(rank, world_size, args):
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '29500'
     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
@@ -265,7 +294,7 @@ def run_worker(rank, args):
         rpc.init_rpc(
             "master",
             rank=rank,
-            world_size=args.gpus+1,
+            world_size=world_size,
             rpc_backend_options=options
         )
         run_main(args)
@@ -278,7 +307,7 @@ def run_worker(rank, args):
         rpc.init_rpc(
             f"worker{rank}",
             rank=rank,
-            world_size=args.gpus+1,
+            world_size=world_size,
            rpc_backend_options=options
         )
         pass
@@ -337,5 +366,8 @@ def run_worker(rank, args):
 
     if args.pipeline_mode == 'sp':
         run_main(args)
+    elif args.pipeline_mode == 'pipe':
+        # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef and requires RPC
+        mp.spawn(run_worker, args=(1, args), nprocs=1, join=True)
     else:
-        mp.spawn(run_worker, args=(args,), nprocs=args.gpus+1, join=True)
+        mp.spawn(run_worker, args=(args.gpus+1, args), nprocs=args.gpus+1, join=True)
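
For context on the chunks=args.batch_size // args.split_size wiring above: chunks is the number of micro-batches Pipe cuts each input batch into, so split_size acts as the micro-batch size and batch_size should be divisible by it. A tiny illustration with assumed values (the script's actual defaults are not shown in this diff):

batch_size, split_size = 32, 8      # assumed values, for illustration only
chunks = batch_size // split_size   # Pipe splits each 32-sample batch into 4 micro-batches of 8
assert chunks * split_size == batch_size
print(chunks)                       # 4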

examples/BERT/pipeline.py  (+5 -7)

@@ -5,6 +5,7 @@
 import threading
 import concurrent.futures
 
+
 class ToDevice(nn.Module):
     def __init__(self, device):
         super().__init__()
@@ -15,19 +16,17 @@ def forward(self, x):
 
 
 class SingleProcessPipeline(nn.Sequential):
-    def __init__(self, shards, devices):
+    def __init__(self, shards, devices, to_device=True):
         super().__init__()
         assert len(shards) == len(devices)
         self.devices = devices
-        self.seq = nn.Sequential()
-
         with concurrent.futures.ThreadPoolExecutor() as executor:
             concurrent.futures.wait([executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])
 
         for i, shard in enumerate(shards):
-            self.seq.add_module(f'Shard({devices[i]})', shard)
-            if i != len(shards)-1:
-                self.seq.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))
+            self.add_module(f'Shard({devices[i]})', shard)
+            if to_device and i != len(shards)-1:
+                self.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))
 
 
 class RemoteBaseCPURPC(nn.Module):
@@ -87,4 +86,3 @@ def parameter_rrefs(self):
         for shard in self.shards:
             remote_params.extend(shard.remote().parameter_rrefs().to_here())
         return remote_params
-
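
The to_device flag added above exists because the two consumers of SingleProcessPipeline need different glue: in 'sp' mode the module must move activations between shard devices itself via ToDevice, while in 'pipe' mode torch.distributed.pipeline.sync.Pipe moves micro-batches across devices on its own, so the extra ToDevice modules would be redundant. A rough sketch of the two resulting module shapes, using CPU-only toy shards as an assumption (the ToDevice forward body is reconstructed, since this diff only shows its constructor):

import torch
import torch.nn as nn

class ToDevice(nn.Module):
    # Plays the same role as ToDevice in examples/BERT/pipeline.py.
    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, x):
        return x.to(self.device)

shards = [nn.Linear(8, 8), nn.Linear(8, 8)]
devices = ["cpu", "cpu"]  # assumption; the example targets one device per shard

# to_device=True: ToDevice modules glue consecutive shards together, so a plain
# forward pass handles the cross-device hop without any pipeline framework.
with_glue = nn.Sequential()
for i, shard in enumerate(shards):
    with_glue.add_module(f'Shard({devices[i]})', shard.to(devices[i]))
    if i != len(shards) - 1:
        with_glue.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))

# to_device=False: only the shards are registered; Pipe is expected to move
# micro-batches between devices by itself.
without_glue = nn.Sequential(*(s.to(d) for s, d in zip(shards, devices)))

x = torch.randn(4, 8)
print(with_glue(x).shape, without_glue(x).shape)  # torch.Size([4, 8]) twice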