Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

args.pipeline_mode=pipe to use torch.distributed.pipeline.sync.Pipe #1210

Draft
wants to merge 1 commit into
base: xlmr_mlm
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 57 additions & 24 deletions examples/BERT/cross_lingual_mlm_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.pipeline.sync import Pipe
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from data import CC100
from model import CrossLingualMLMTask
from pipeline import SingleProcessPipeline, RPCPipeline, RemoteBaseCPURPC, RemoteBaseCUDARPC
from shard_model import XLMRModelShards, MLMShards
from torchtext.experimental.models.utils import count_model_param
from torchtext.experimental.transforms import sentencepiece_tokenizer
from transforms import PretrainedSPVocab
from torchtext.experimental.models.utils import count_model_param


def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
Expand Down Expand Up @@ -58,7 +58,7 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
return total_loss / (len(data_source) - 1) # Set batch # to 1 for inference


def local_step(model, data, targets, criterion, optimizer, ntokens):
def local_step(model, data, targets, criterion, optimizer, ntokens, args):
optimizer.zero_grad()
output = model(data)
loss = criterion(output.view(-1, ntokens), targets.view(-1))
Expand All @@ -69,7 +69,18 @@ def local_step(model, data, targets, criterion, optimizer, ntokens):
return res


def dist_step(model, data, targets, criterion, optimizer, ntokens):
def pipe_step(model, data, targets, criterion, optimizer, ntokens, args):
    """Run one optimisation step for a model wrapped in ``Pipe``.

    The sequence is: clear gradients, forward pass, loss, backward pass,
    gradient-norm clipping, optimizer update.

    Args:
        model: Callable module. ``model(data)`` must return an object that
            exposes ``local_value()`` (``Pipe.forward`` returns an RRef),
            and ``model.parameters()`` must yield the trainable parameters.
        data: Input batch handed to ``model``.
        targets: Target tensor; flattened with ``view(-1)`` for the loss.
        criterion: Loss callable invoked as ``criterion(logits, targets)``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        ntokens: Size of the last dimension the output is reshaped to.
        args: Namespace providing ``clip``, the maximum gradient norm.

    Returns:
        The loss as a Python float, read before clipping and before the
        optimizer update.
    """
    optimizer.zero_grad()
    # Pipe.forward returns an RRef, so unwrap it to get the local tensor.
    output = model(data).local_value()
    loss = criterion(output.view(-1, ntokens), targets.view(-1))
    loss.backward()
    res = loss.item()
    # Clip after backward so the bound applies to the freshly computed grads.
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
    optimizer.step()
    return res


def rpc_step(model, data, targets, criterion, optimizer, ntokens, args):
with dist_autograd.context() as context_id:
output = model(data)
loss = criterion(output.view(-1, ntokens), targets.view(-1))
Expand All @@ -91,19 +102,19 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
for batch, (data, targets) in enumerate(dataloader):
data = data.to(devices[0])
targets = targets.to(devices[-1])
loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
loss = step_impl(model, data, targets, criterion, optimizer, ntokens, args)

total_loss += loss
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss / args.log_interval
elapsed = time.time() - start_time
train_loss_log[-1] = cur_loss
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
len(train_data) // args.batch_size,
last_lr,
elapsed * 1000 / args.log_interval,
cur_loss, math.exp(cur_loss)))
'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
len(train_data) // args.batch_size,
last_lr,
elapsed * 1000 / args.log_interval,
cur_loss, math.exp(cur_loss)))
total_loss = 0
start_time = time.time()

Expand Down Expand Up @@ -171,18 +182,27 @@ def text_transform(x: str) -> List:
print("Allocating memory")
if args.pipeline_mode == 'sp':
model = SingleProcessPipeline(shards, devices)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
else:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.75)
elif args.pipeline_mode == 'pipe':
model = Pipe(SingleProcessPipeline(shards, devices, to_device=False), chunks=args.batch_size // args.split_size)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.75)
elif args.pipeline_mode == 'cpu' or args.pipeline_mode == 'cuda':
workers = [f"worker{i+1}" for i in range(len(devices))]
model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
if args.pipeline_mode == 'cpu':
impl = RemoteBaseCPURPC
elif args.pipeline_mode == 'cuda':
impl = RemoteBaseCUDARPC
model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=impl)
optimizer = DistributedOptimizer(
optim.Adam,
model.parameter_rrefs(),
lr=args.lr,
)
scheduler = None
else:
raise ValueError("Unsupported pipeline_mode")

print("Memory allocated")
# input("Memory allocated, check nvidia-smi for memory consumption")
Expand All @@ -199,16 +219,26 @@ def text_transform(x: str) -> List:

epoch_start_time = time.time()
last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr

if args.pipeline_mode == 'sp':
step = local_step
elif args.pipeline_mode == 'pipe':
step = pipe_step
else:
step = rpc_step

if args.pipeline_mode == 'cpu':
train_devices = ["cpu"] # Because "TensorPipe RPC backend only supports CPU tensors by default,
# please move your tensors to CPU before sending them over RPC"
else:
train_devices = devices

train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
optimizer, criterion, ntokens, epoch, last_lr, args,
devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
local_step if args.pipeline_mode == 'sp' else dist_step)
optimizer, criterion, ntokens, epoch, last_lr, args, train_devices, step)

# Turn on evaluation mode which disables dropout.
model.eval()
val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
text_transform)
val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, train_devices, text_transform)
val_loss_log.append(val_loss)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
Expand Down Expand Up @@ -253,7 +283,7 @@ def _forward(x):
print('-' * 89)


def run_worker(rank, args):
def run_worker(rank, world_size, args):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
Expand All @@ -265,7 +295,7 @@ def run_worker(rank, args):
rpc.init_rpc(
"master",
rank=rank,
world_size=args.gpus+1,
world_size=world_size,
rpc_backend_options=options
)
run_main(args)
Expand All @@ -278,7 +308,7 @@ def run_worker(rank, args):
rpc.init_rpc(
f"worker{rank}",
rank=rank,
world_size=args.gpus+1,
world_size=world_size,
rpc_backend_options=options
)
pass
Expand Down Expand Up @@ -337,5 +367,8 @@ def run_worker(rank, args):

if args.pipeline_mode == 'sp':
run_main(args)
elif args.pipeline_mode == 'pipe':
# Because torch.distributed.pipeline.sync.Pipe.forward returns RRef and requires RPC
mp.spawn(run_worker, args=(1, args), nprocs=1, join=True)
else:
mp.spawn(run_worker, args=(args,), nprocs=args.gpus+1, join=True)
mp.spawn(run_worker, args=(args.gpus+1, args), nprocs=args.gpus+1, join=True)
25 changes: 13 additions & 12 deletions examples/BERT/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import concurrent.futures
import threading

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch.distributed.rpc import RRef
import threading
import concurrent.futures


class ToDevice(nn.Module):
def __init__(self, device):
Expand All @@ -15,19 +17,18 @@ def forward(self, x):


class SingleProcessPipeline(nn.Sequential):
def __init__(self, shards, devices):
def __init__(self, shards, devices, to_device=True):
super().__init__()
assert len(shards) == len(devices)
self.devices = devices
self.seq = nn.Sequential()

with concurrent.futures.ThreadPoolExecutor() as executor:
concurrent.futures.wait([executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])
concurrent.futures.wait(
[executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])

for i, shard in enumerate(shards):
self.seq.add_module(f'Shard({devices[i]})', shard)
if i != len(shards)-1:
self.seq.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))
self.add_module(f'Shard({devices[i]})', shard)
if to_device and i != len(shards) - 1:
self.add_module(f'ToDevice({devices[i + 1]})', ToDevice(devices[i + 1]))


class RemoteBaseCPURPC(nn.Module):
Expand Down Expand Up @@ -66,7 +67,8 @@ class RPCPipeline(nn.Module):
def __init__(self, shards, devices, workers, remote_base_class=RemoteBaseCPURPC, split_size=1):
super().__init__()
self.split_size = split_size
self.shards = [rpc.remote(worker, remote_base_class, args=(shard, device)) for worker, shard, device in zip(workers, shards, devices)]
self.shards = [rpc.remote(worker, remote_base_class, args=(shard, device)) for worker, shard, device in
zip(workers, shards, devices)]

def forward(self, xs):
# Split the input batch xs into micro-batches, and collect async RPC
Expand All @@ -87,4 +89,3 @@ def parameter_rrefs(self):
for shard in self.shards:
remote_params.extend(shard.remote().parameter_rrefs().to_here())
return remote_params

3 changes: 2 additions & 1 deletion examples/BERT/shard_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch.nn as nn

from model import TransformerEncoderLayer, XLMREmbedding


class XLMRModelShards():
def __init__(self, ntoken, ninp, nhead, nhid, dropout=0.5):
self.ntoken = ntoken
Expand All @@ -15,7 +17,6 @@ def encoder_layers(self, nlayers):
return nn.TransformerEncoder(self.encoder_layer, nlayers)



class MLMShards():
def __init__(self, ntoken, ninp):
self.ntoken = ntoken
Expand Down