SingleProcessPipeline and RPCPipeline with CPU RPC and CUDA RPC (#1177)
pbelevich authored Feb 22, 2021
1 parent 2b9a280 commit ea14d26
Showing 4 changed files with 221 additions and 146 deletions.
133 changes: 101 additions & 32 deletions examples/BERT/cross_lingual_mlm_task.py
@@ -15,10 +15,12 @@
from torch.utils.data import DataLoader

from data import CC100
from dist_model import DistCrossLingualMLMTask
from model import CrossLingualMLMTask
from pipeline import SingleProcessPipeline, RPCPipeline, RemoteBaseCPURPC, RemoteBaseCUDARPC
from shard_model import XLMRModelShards, MLMShards
from torchtext.experimental.transforms import sentencepiece_tokenizer
from transforms import PretrainedSPVocab
from torchtext.experimental.models.utils import count_model_param


def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -43,27 +45,28 @@ def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
return batch_data, targets


def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, device, text_transform):
def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devices, text_transform):
total_loss = 0.
dataloader = DataLoader(data_source, batch_size=1, # Set batch # to 1 for inference
shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))
with torch.no_grad():
for batch, (data, targets) in enumerate(dataloader):
data = data.to(device)
targets = targets.to(device)
data = data.to(devices[0])
targets = targets.to(devices[-1])
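# The input batch is placed on the first shard's device and the targets on the
# last device, where the pipeline output lands and the loss is computed.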
output = model(data)
total_loss += criterion(output.view(-1, ntokens), targets.view(-1)).item()
return total_loss / (len(data_source) - 1) # Set batch # to 1 for inference


def step(model, data, targets, criterion, optimizer, ntokens):
def local_step(model, data, targets, criterion, optimizer, ntokens):
optimizer.zero_grad()
output = model(data)
loss = criterion(output.view(-1, ntokens), targets.view(-1))
loss.backward()
res = loss.item()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
return loss
return res


def dist_step(model, data, targets, criterion, optimizer, ntokens):
@@ -73,11 +76,11 @@ def dist_step(model, data, targets, criterion, optimizer, ntokens):
dist_autograd.backward(context_id, [loss])
# torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step(context_id)
return loss
return loss.item()


def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
optimizer, criterion, ntokens, epoch, last_lr, args, device, step_impl):
optimizer, criterion, ntokens, epoch, last_lr, args, devices, step_impl):
model.train()
total_loss = 0.
start_time = time.time()
@@ -86,9 +89,11 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))

for batch, (data, targets) in enumerate(dataloader):
loss = step_impl(model, data.to(device), targets.to(device), criterion, optimizer, ntokens)
data = data.to(devices[0])
targets = targets.to(devices[-1])
loss = step_impl(model, data, targets, criterion, optimizer, ntokens)

total_loss += loss.item()
total_loss += loss
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss / args.log_interval
elapsed = time.time() - start_time
@@ -116,47 +121,101 @@ def text_transform(x: str) -> List:
pad_id = vocab(['pad'])[0]
ntokens = len(vocab)

if not args.dist:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
model = model.to(device)
xlmr = XLMRModelShards(ntokens, args.emsize, args.nhead, args.nhid, args.dropout)
mlm = MLMShards(ntokens, args.emsize)
devices = [f"cuda:{i}" for i in range(args.gpus)] if torch.cuda.is_available() else ["cpu"]

if len(devices) == 1:
# In case of one device combine all layers into a single nn.Sequential
shards = [nn.Sequential(
xlmr.xlmr_embed(),
xlmr.encoder_layers(args.nlayers),
mlm.mlm()
)]
elif len(devices) == 2:
# In case of two devices split the model right in the middle and
# put the embeddings and half of encoders to the first shard and
# another half of encoders and mlm head to the second.
assert args.nlayers % 2 == 0
shards = [
nn.Sequential(
xlmr.xlmr_embed(),
xlmr.encoder_layers(args.nlayers // 2)
),
nn.Sequential(
xlmr.encoder_layers(args.nlayers // 2),
mlm.mlm()
)
]
else:
# In case of more than 2 devices put the embeddings and mlm head
# on the first and the last shard and split the encoders into equal
# parts among the rest of the shards
encoder_gpus = (args.gpus - 2)
assert args.nlayers % encoder_gpus == 0
encoders_per_gpu = args.nlayers // encoder_gpus
shards = [
xlmr.xlmr_embed(),
*[xlmr.encoder_layers(encoders_per_gpu) for _ in range(encoder_gpus)],
mlm.mlm()
]
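# For example (hypothetical flag values): with --gpus 4 and --nlayers 12,
# encoder_gpus = 2 and encoders_per_gpu = 6, so the shards are
# [embeddings, 6 encoder layers, 6 encoder layers, mlm head].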

print('Shards parameters:')
total = 0
for i, shard in enumerate(shards):
params = count_model_param(shard)
total += params
print(f'shard{i} = {int(params)}M')
print(f'total = {int(total)}M')

print("Allocating memory")
if args.pipeline_mode == 'sp':
model = SingleProcessPipeline(shards, devices)
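# 'sp' mode: SingleProcessPipeline runs all shards in this process, presumably
# placing shard i on devices[i] (see pipeline.py introduced by this commit).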

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
else:
device = "cpu"
model = DistCrossLingualMLMTask(args.split_size, ["worker1", "worker2"], ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
workers = [f"worker{i+1}" for i in range(len(devices))]
model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
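# 'cpu'/'cuda' modes: each shard is owned by a remote RPC worker. Judging by the
# class names, RemoteBaseCPURPC moves activations between workers through CPU
# memory, while RemoteBaseCUDARPC relies on the TensorPipe device maps set up in
# run_worker below to exchange CUDA tensors between GPUs directly.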
optimizer = DistributedOptimizer(
optim.Adam,
model.parameter_rrefs(),
lr=args.lr,
)
scheduler = None

print("Memory allocated")
# input("Memory allocated, check nvidia-smi for memory consumption")

criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
best_val_loss = None
train_loss_log, val_loss_log = [], []

for epoch in range(1, args.epochs + 1):
train_data = CC100('/datasets01/cc100/031720/', {'*.txt'}, start_line=args.start_line, num_lines=args.num_lines)
train_data = CC100(args.cc100_path, {'*.txt'}, start_line=args.start_line, num_lines=args.num_lines)
from torchtext.datasets import WikiText2
val_data, = WikiText2(data_select='valid')
val_data = WikiText2(split='valid')
val_data = [(17, item) for item in val_data if item != ' \n'] # english language type is 17 in CC100 dataset

epoch_start_time = time.time()
last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
optimizer, criterion, ntokens, epoch, last_lr, args, device, step if not args.dist else dist_step)
optimizer, criterion, ntokens, epoch, last_lr, args,
devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
local_step if args.pipeline_mode == 'sp' else dist_step)

# Turn on evaluation mode which disables dropout.
model.eval()
val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, device, text_transform)
val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
text_transform)
val_loss_log.append(val_loss)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-' * 89)
if not args.dist and not best_val_loss or val_loss < best_val_loss:
if args.pipeline_mode == 'sp' and not best_val_loss or val_loss < best_val_loss:
with open(args.save, 'wb') as f:
torch.save(model, f)
best_val_loss = val_loss
@@ -173,7 +232,7 @@ def text_transform(x: str) -> List:
def text_transform(x: str) -> List:
return ref_model.encode(x).tolist()
model = ref_model.model.encoder
model = model.to(device)
model = model.to(devices[0])
# Turn on evaluation mode which disables dropout.
model.eval()
# from fairseq XLM-R model
@@ -187,7 +246,7 @@ def _forward(x):
return nn_model(x.transpose(0, 1))[0].transpose(0, 1)
return _forward
val_loss = evaluate(val_data, model_forward(model), mask_id, pad_id, ref_ntokens,
criterion, args, device, text_transform)
criterion, args, devices[0], text_transform)
print('-' * 89)
print('| reference model | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(val_loss, math.exp(val_loss)))
@@ -200,18 +259,26 @@ def run_worker(rank, args):
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)

if rank == 0:
if args.pipeline_mode == 'cuda':
for i in range(args.gpus):
options.set_device_map("worker" + str(i + 1), {i:i})
rpc.init_rpc(
"master",
rank=rank,
world_size=args.world_size,
world_size=args.gpus+1,
rpc_backend_options=options
)
run_main(args)
else:
if args.pipeline_mode == 'cuda':
if rank == 1:
options.set_device_map("master", {0:0})
else:
options.set_device_map("worker" + str(rank - 1), {(rank - 1):(rank - 2)})
rpc.init_rpc(
f"worker{rank}",
rank=rank,
world_size=args.world_size,
world_size=args.gpus+1,
rpc_backend_options=options
)
pass
@@ -258,15 +325,17 @@ def run_worker(rank, args):
help='path to load the reference model for evaluation')
parser.add_argument('--mask_frac', type=float, default=0.15,
help='the fraction of masked tokens')
parser.add_argument('--dist', action='store_true',
help='run distributed version')
parser.add_argument('--world_size', type=int, default=3,
help='world_size')
parser.add_argument('--cc100_path', type=str, default='/datasets01/cc100/031720/',
help='path to cc100')
parser.add_argument('--gpus', type=int, default=1,
help='number of GPUs to use')
parser.add_argument('--pipeline_mode', type=str, default='sp',
help='pipeline mode, `cpu` for CPU RPC, `cuda` for CUDA RPC, `sp` for single process pipeline')
parser.add_argument('--split_size', type=int, default=8,
help='split the input batch into micro-batches')
args = parser.parse_args()

if args.dist:
mp.spawn(run_worker, args=(args,), nprocs=args.world_size, join=True)
else:
if args.pipeline_mode == 'sp':
run_main(args)
else:
mp.spawn(run_worker, args=(args,), nprocs=args.gpus+1, join=True)
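
Example invocations for the three pipeline modes (a sketch based on the flags added in this diff; values for unlisted flags and the dataset path are assumptions):

python cross_lingual_mlm_task.py --pipeline_mode sp --gpus 2 --nlayers 12
python cross_lingual_mlm_task.py --pipeline_mode cpu --gpus 2 --nlayers 12
python cross_lingual_mlm_task.py --pipeline_mode cuda --gpus 4 --nlayers 12 --split_size 8

With more than two GPUs, --nlayers must be divisible by (gpus - 2); with exactly two GPUs it must be even.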
114 changes: 0 additions & 114 deletions examples/BERT/dist_model.py

This file was deleted.

