From ea14d26b0a7b7c533bc0c109be89d4a080bba46e Mon Sep 17 00:00:00 2001
From: Pavel Belevich
Date: Mon, 22 Feb 2021 14:53:37 -0500
Subject: [PATCH] SingleProcessPipeline and RPCPipeline with CPU RPC and CUDA RPC (#1177)

---
 examples/BERT/cross_lingual_mlm_task.py | 133 ++++++++++++++++++------
 examples/BERT/dist_model.py             | 114 --------------------
 examples/BERT/pipeline.py               |  90 ++++++++++++++++
 examples/BERT/shard_model.py            |  30 ++++++
 4 files changed, 221 insertions(+), 146 deletions(-)
 delete mode 100644 examples/BERT/dist_model.py
 create mode 100644 examples/BERT/pipeline.py
 create mode 100644 examples/BERT/shard_model.py

diff --git a/examples/BERT/cross_lingual_mlm_task.py b/examples/BERT/cross_lingual_mlm_task.py
index ca448dc568..916b88f93d 100644
--- a/examples/BERT/cross_lingual_mlm_task.py
+++ b/examples/BERT/cross_lingual_mlm_task.py
@@ -15,10 +15,12 @@
 from torch.utils.data import DataLoader
 
 from data import CC100
-from dist_model import DistCrossLingualMLMTask
 from model import CrossLingualMLMTask
+from pipeline import SingleProcessPipeline, RPCPipeline, RemoteBaseCPURPC, RemoteBaseCUDARPC
+from shard_model import XLMRModelShards, MLMShards
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
+from torchtext.experimental.models.utils import count_model_param
 
 
 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -43,27 +45,28 @@ def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
     return batch_data, targets
 
 
-def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, device, text_transform):
+def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devices, text_transform):
     total_loss = 0.
     dataloader = DataLoader(data_source, batch_size=1,  # Set batch # to 1 for inference
                             shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))
     with torch.no_grad():
         for batch, (data, targets) in enumerate(dataloader):
-            data = data.to(device)
-            targets = targets.to(device)
+            data = data.to(devices[0])
+            targets = targets.to(devices[-1])
             output = model(data)
             total_loss += criterion(output.view(-1, ntokens), targets.view(-1)).item()
     return total_loss / (len(data_source) - 1)  # Set batch # to 1 for inference
 
 
-def step(model, data, targets, criterion, optimizer, ntokens):
+def local_step(model, data, targets, criterion, optimizer, ntokens):
     optimizer.zero_grad()
     output = model(data)
     loss = criterion(output.view(-1, ntokens), targets.view(-1))
     loss.backward()
+    res = loss.item()
     torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
     optimizer.step()
-    return loss
+    return res
 
 
 def dist_step(model, data, targets, criterion, optimizer, ntokens):
@@ -73,11 +76,11 @@ def dist_step(model, data, targets, criterion, optimizer, ntokens):
         dist_autograd.backward(context_id, [loss])
         # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step(context_id)
-        return loss
+        return loss.item()
 
 
 def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-          optimizer, criterion, ntokens, epoch, last_lr, args, device, step_impl):
+          optimizer, criterion, ntokens, epoch, last_lr, args, devices, step_impl):
     model.train()
     total_loss = 0.
    start_time = time.time()
@@ -86,9 +89,11 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
                             shuffle=False, collate_fn=lambda b: collate_batch(b, args, mask_id, pad_id, text_transform))
     for batch, (data, targets) in enumerate(dataloader):
-        loss = step_impl(model, data.to(device), targets.to(device), criterion, optimizer, ntokens)
+        data = data.to(devices[0])
+        targets = targets.to(devices[-1])
+        loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
 
-        total_loss += loss.item()
+        total_loss += loss
         if batch % args.log_interval == 0 and batch > 0:
             cur_loss = total_loss / args.log_interval
             elapsed = time.time() - start_time
@@ -116,15 +121,62 @@ def text_transform(x: str) -> List:
     pad_id = vocab(['pad'])[0]
     ntokens = len(vocab)
 
-    if not args.dist:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model = CrossLingualMLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
-        model = model.to(device)
+    xlmr = XLMRModelShards(ntokens, args.emsize, args.nhead, args.nhid, args.dropout)
+    mlm = MLMShards(ntokens, args.emsize)
+    devices = [f"cuda:{i}" for i in range(args.gpus)] if torch.cuda.is_available() else ["cpu"]
+
+    if len(devices) == 1:
+        # In case of one device, combine all layers into a single nn.Sequential
+        shards = [nn.Sequential(
+            xlmr.xlmr_embed(),
+            xlmr.encoder_layers(args.nlayers),
+            mlm.mlm()
+        )]
+    elif len(devices) == 2:
+        # In case of two devices, split the model in the middle: put the
+        # embeddings and half of the encoders on the first shard, and the
+        # other half of the encoders plus the MLM head on the second.
+        assert args.nlayers % 2 == 0
+        shards = [
+            nn.Sequential(
+                xlmr.xlmr_embed(),
+                xlmr.encoder_layers(args.nlayers // 2)
+            ),
+            nn.Sequential(
+                xlmr.encoder_layers(args.nlayers // 2),
+                mlm.mlm()
+            )
+        ]
+    else:
+        # In case of more than 2 devices, put the embeddings on the first
+        # shard and the MLM head on the last one, and split the encoders
+        # into equal parts among the remaining shards.
+        encoder_gpus = (args.gpus - 2)
+        assert args.nlayers % encoder_gpus == 0
+        encoders_per_gpu = args.nlayers // encoder_gpus
+        shards = [
+            xlmr.xlmr_embed(),
+            *[xlmr.encoder_layers(encoders_per_gpu) for _ in range(encoder_gpus)],
+            mlm.mlm()
+        ]
+
+    print('Shards parameters:')
+    total = 0
+    for i, shard in enumerate(shards):
+        params = count_model_param(shard)
+        total += params
+        print(f'shard{i} = {int(params)}M')
+    print(f'total = {int(total)}M')
+
+    print("Allocating memory")
+    if args.pipeline_mode == 'sp':
+        model = SingleProcessPipeline(shards, devices)
+
         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
     else:
-        device = "cpu"
-        model = DistCrossLingualMLMTask(args.split_size, ["worker1", "worker2"], ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        workers = [f"worker{i+1}" for i in range(len(devices))]
+        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
         optimizer = DistributedOptimizer(
             optim.Adam,
             model.parameter_rrefs(),
@@ -132,31 +184,38 @@ def text_transform(x: str) -> List:
         )
         scheduler = None
 
+    print("Memory allocated")
+    # input("Memory allocated, check nvidia-smi for memory consumption")
+
     criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
     best_val_loss = None
     train_loss_log, val_loss_log = [], []
 
     for epoch in range(1, args.epochs + 1):
-        train_data = CC100('/datasets01/cc100/031720/', {'*.txt'}, start_line=args.start_line, num_lines=args.num_lines)
+        train_data = CC100(args.cc100_path, {'*.txt'}, start_line=args.start_line, num_lines=args.num_lines)
         from torchtext.datasets import WikiText2
-        val_data, = WikiText2(data_select='valid')
+        val_data = WikiText2(split='valid')
         val_data = [(17, item) for item in val_data if item != ' \n']  # English language type is 17 in the CC100 dataset
 
         epoch_start_time = time.time()
         last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, last_lr, args, device, step if not args.dist else dist_step)
+              optimizer, criterion, ntokens, epoch, last_lr, args,
+              devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
+              local_step if args.pipeline_mode == 'sp' else dist_step)
 
         # Turn on evaluation mode which disables dropout.
         model.eval()
-        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, device, text_transform)
+        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
+                            devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
+                            text_transform)
         val_loss_log.append(val_loss)
         print('-' * 89)
         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
               'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                          val_loss, math.exp(val_loss)))
         print('-' * 89)
-        if not args.dist and not best_val_loss or val_loss < best_val_loss:
+        if args.pipeline_mode == 'sp' and (not best_val_loss or val_loss < best_val_loss):
             with open(args.save, 'wb') as f:
                 torch.save(model, f)
             best_val_loss = val_loss
@@ -173,7 +232,7 @@ def text_transform(x: str) -> List:
         def text_transform(x: str) -> List:
             return ref_model.encode(x).tolist()
     model = ref_model.model.encoder
-    model = model.to(device)
+    model = model.to(devices[0])
     # Turn on evaluation mode which disables dropout.
    model.eval()
    # from fairseq XLM-R model
@@ -187,7 +246,7 @@ def _forward(x):
             return nn_model(x.transpose(0, 1))[0].transpose(0, 1)
         return _forward
     val_loss = evaluate(val_data, model_forward(model), mask_id, pad_id, ref_ntokens,
-                        criterion, args, device, text_transform)
+                        criterion, args, [devices[0]], text_transform)
     print('-' * 89)
     print('| reference model | valid loss {:5.2f} | '
           'valid ppl {:8.2f}'.format(val_loss, math.exp(val_loss)))
@@ -200,18 +259,26 @@ def run_worker(rank, args):
     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
 
     if rank == 0:
+        if args.pipeline_mode == 'cuda':
+            for i in range(args.gpus):
+                options.set_device_map("worker" + str(i + 1), {i: i})
         rpc.init_rpc(
             "master",
             rank=rank,
-            world_size=args.world_size,
+            world_size=args.gpus + 1,
             rpc_backend_options=options
         )
         run_main(args)
     else:
+        if args.pipeline_mode == 'cuda':
+            if rank == 1:
+                options.set_device_map("master", {0: 0})
+            else:
+                options.set_device_map("worker" + str(rank - 1), {(rank - 1): (rank - 2)})
         rpc.init_rpc(
             f"worker{rank}",
             rank=rank,
-            world_size=args.world_size,
+            world_size=args.gpus + 1,
             rpc_backend_options=options
         )
         pass
@@ -258,15 +325,17 @@ def run_worker(rank, args):
                         help='path to load the reference model for evaluation')
     parser.add_argument('--mask_frac', type=float, default=0.15,
                         help='the fraction of masked tokens')
-    parser.add_argument('--dist', action='store_true',
-                        help='run distributed version')
-    parser.add_argument('--world_size', type=int, default=3,
-                        help='world_size')
+    parser.add_argument('--cc100_path', type=str, default='/datasets01/cc100/031720/',
+                        help='path to cc100')
+    parser.add_argument('--gpus', type=int, default=1,
+                        help='number of GPUs to use')
+    parser.add_argument('--pipeline_mode', type=str, default='sp',
+                        help='pipeline mode, `cpu` for CPU RPC, `cuda` for CUDA RPC, `sp` for single process pipeline')
     parser.add_argument('--split_size', type=int, default=8,
                         help='split the input batch into micro-batches')
     args = parser.parse_args()
 
-    if args.dist:
-        mp.spawn(run_worker, args=(args,), nprocs=args.world_size, join=True)
-    else:
+    if args.pipeline_mode == 'sp':
         run_main(args)
+    else:
+        mp.spawn(run_worker, args=(args,), nprocs=args.gpus + 1, join=True)
diff --git a/examples/BERT/dist_model.py b/examples/BERT/dist_model.py
deleted file mode 100644
index d87e1b23be..0000000000
--- a/examples/BERT/dist_model.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import threading
-
-import torch
-import torch.distributed.rpc as rpc
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.distributed.rpc import RRef
-from torch.nn import Linear, LayerNorm
-
-from model import XLMREmbedding, TransformerEncoderLayer, TransformerEncoder
-
-
-def get_cuda_if_available(i):
-    assert i >= 0
-    if torch.cuda.is_available():
-        return f"cuda:{min(i, torch.cuda.device_count() - 1)}"
-    else:
-        return "cpu"
-
-
-class CrossLingualMLMTaskBase(nn.Module):
-    def __init__(self, device):
-        super(CrossLingualMLMTaskBase, self).__init__()
-        self.device = device
-        self._lock = threading.Lock()
-
-    def forward(self, x_rref):
-        x = x_rref.to_here().to(self.device)
-        with self._lock:
-            out = self._forward(x)
-        return out.cpu()
-
-    def parameter_rrefs(self):
-        r"""
-        Create one RRef for each parameter in the given local module, and return a
-        list of RRefs.
- """ - return [RRef(p) for p in self.parameters()] - - -class CrossLingualMLMTaskShard1(CrossLingualMLMTaskBase): - def __init__(self, device, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): - super(CrossLingualMLMTaskShard1, self).__init__(device) - self.xlmr_embed = XLMREmbedding(ntoken, ninp, dropout).to(device) - encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers // 2).to(device) - - def _forward(self, src): - output = self.xlmr_embed(src) - output = self.transformer_encoder(output) - return output - - -class CrossLingualMLMTaskShard2(CrossLingualMLMTaskBase): - def __init__(self, device, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): - super(CrossLingualMLMTaskShard2, self).__init__(device) - encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers // 2).to(device) - self.mlm_span = Linear(ninp, ninp).to(device) - self.activation = F.gelu - self.norm_layer = LayerNorm(ninp, eps=1e-12).to(device) - self.mlm_head = Linear(ninp, ntoken).to(device) - - def _forward(self, src): - output = self.transformer_encoder(src) - output = self.mlm_span(output) - output = self.activation(output) - output = self.norm_layer(output) - output = self.mlm_head(output) - return output - - -class DistCrossLingualMLMTask(nn.Module): - """Two shards CrossLingualMLMTask""" - - def __init__(self, split_size, workers, *args, **kwargs): - super(DistCrossLingualMLMTask, self).__init__() - - self.split_size = split_size - - # Put the first part of the ResNet50 on workers[0] - self.p1_rref = rpc.remote( - workers[0], - CrossLingualMLMTaskShard1, - args=(get_cuda_if_available(0),) + args, - kwargs=kwargs - ) - - # Put the second part of the ResNet50 on workers[1] - self.p2_rref = rpc.remote( - workers[1], - CrossLingualMLMTaskShard2, - args=(get_cuda_if_available(1),) + args, - kwargs=kwargs - ) - - def forward(self, xs): - # Split the input batch xs into micro-batches, and collect async RPC - # futures into a list - out_futures = [] - for x in iter(xs.split(self.split_size, dim=0)): - x_rref = RRef(x) - y_rref = self.p1_rref.remote().forward(x_rref) - z_fut = self.p2_rref.rpc_async().forward(y_rref) - out_futures.append(z_fut) - - # collect and cat all output tensors into one tensor. 
-        return torch.cat(torch.futures.wait_all(out_futures))
-
-    def parameter_rrefs(self):
-        remote_params = []
-        remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
-        remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
-        return remote_params
diff --git a/examples/BERT/pipeline.py b/examples/BERT/pipeline.py
new file mode 100644
index 0000000000..bc6e8bd7a2
--- /dev/null
+++ b/examples/BERT/pipeline.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc import RRef
+import threading
+import concurrent.futures
+
+
+class ToDevice(nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        self.device = device
+
+    def forward(self, x):
+        return x.to(self.device)
+
+
+class SingleProcessPipeline(nn.Sequential):
+    def __init__(self, shards, devices):
+        super().__init__()
+        assert len(shards) == len(devices)
+        self.devices = devices
+        self.seq = nn.Sequential()
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            concurrent.futures.wait([executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])
+
+        for i, shard in enumerate(shards):
+            self.seq.add_module(f'Shard({devices[i]})', shard)
+            if i != len(shards) - 1:
+                self.seq.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))
+
+
+class RemoteBaseCPURPC(nn.Module):
+    def __init__(self, underlying, device):
+        super().__init__()
+        self.underlying = underlying.to(device)
+        self.device = device
+        self._lock = threading.Lock()
+
+    def forward(self, x_rref):
+        x = x_rref.to_here().to(self.device)
+        with self._lock:
+            out = self.underlying(x)
+        return out.cpu()
+
+    def parameter_rrefs(self):
+        return [RRef(p) for p in self.parameters()]
+
+
+class RemoteBaseCUDARPC(nn.Module):
+    def __init__(self, underlying, device):
+        super().__init__()
+        self.underlying = underlying.to(device)
+        self.device = device
+        self._lock = threading.Lock()
+
+    def forward(self, x_rref):
+        with self._lock:
+            return self.underlying(x_rref.to_here())
+
+    def parameter_rrefs(self):
+        return [RRef(p) for p in self.parameters()]
+
+
+class RPCPipeline(nn.Module):
+    def __init__(self, shards, devices, workers, remote_base_class=RemoteBaseCPURPC, split_size=1):
+        super().__init__()
+        self.split_size = split_size
+        self.shards = [rpc.remote(worker, remote_base_class, args=(shard, device)) for worker, shard, device in zip(workers, shards, devices)]
+
+    def forward(self, xs):
+        # Split the input batch xs into micro-batches, and collect async RPC
+        # futures into a list
+        out_futures = []
+        for x in iter(xs.split(self.split_size, dim=0)):
+            x_rref = RRef(x)
+            for shard in self.shards[:-1]:
+                x_rref = shard.remote().forward(x_rref)
+            z_fut = self.shards[-1].rpc_async().forward(x_rref)
+            out_futures.append(z_fut)
+
+        # collect and cat all output tensors into one tensor.
+        return torch.cat(torch.futures.wait_all(out_futures))
+
+    def parameter_rrefs(self):
+        remote_params = []
+        for shard in self.shards:
+            remote_params.extend(shard.remote().parameter_rrefs().to_here())
+        return remote_params
+
\ No newline at end of file
diff --git a/examples/BERT/shard_model.py b/examples/BERT/shard_model.py
new file mode 100644
index 0000000000..b2878dc8f7
--- /dev/null
+++ b/examples/BERT/shard_model.py
@@ -0,0 +1,30 @@
+import torch.nn as nn
+
+from model import TransformerEncoderLayer, XLMREmbedding
+
+
+class XLMRModelShards():
+    def __init__(self, ntoken, ninp, nhead, nhid, dropout=0.5):
+        self.ntoken = ntoken
+        self.ninp = ninp
+        self.dropout = dropout
+        self.encoder_layer = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
+
+    def xlmr_embed(self):
+        return XLMREmbedding(self.ntoken, self.ninp, self.dropout)
+
+    def encoder_layers(self, nlayers):
+        return nn.TransformerEncoder(self.encoder_layer, nlayers)
+
+
+class MLMShards():
+    def __init__(self, ntoken, ninp):
+        self.ntoken = ntoken
+        self.ninp = ninp
+
+    def mlm(self):
+        return nn.Sequential(
+            nn.Linear(self.ninp, self.ninp),
+            nn.GELU(),
+            nn.LayerNorm(self.ninp, eps=1e-12),
+            nn.Linear(self.ninp, self.ntoken)
+        )
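
---

Editorial notes (appended for review; none of the following is part of the committed patch):

1. A minimal usage sketch of the new SingleProcessPipeline. The nn.Linear
   shards and the two-entry device list are made up for illustration; in the
   example script the shards come from XLMRModelShards/MLMShards and the
   device list from --gpus.

    import torch
    import torch.nn as nn
    from pipeline import SingleProcessPipeline

    # Toy stand-ins for the xlmr_embed / encoder_layers / mlm shards.
    shards = [nn.Sequential(nn.Linear(16, 16), nn.ReLU()), nn.Linear(16, 4)]
    devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]

    model = SingleProcessPipeline(shards, devices)
    x = torch.randn(8, 16).to(devices[0])  # inputs start on the first device
    y = model(x)                           # interleaved ToDevice modules move activations
    print(y.shape, y.device)               # torch.Size([8, 4]), on devices[-1]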
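2. RPCPipeline.forward generalizes the deleted DistCrossLingualMLMTask to any
   number of shards: the batch is split into micro-batches, each micro-batch is
   chained through shard.remote().forward(...) RRefs, and only the last stage
   becomes a future. A local, RPC-free sketch of that control flow, with toy
   shards of hypothetical sizes:

    import torch
    import torch.nn as nn

    shards = [nn.Linear(16, 32), nn.Linear(32, 4)]  # stand-ins for the remote shards
    xs = torch.randn(8, 16)
    split_size = 2  # corresponds to the --split_size flag

    outs = []
    for x in xs.split(split_size, dim=0):  # micro-batches, as in RPCPipeline.forward
        for shard in shards:
            x = shard(x)                   # shard.remote().forward(x_rref) over RPC
        outs.append(x)                     # rpc_async() future for the last stage
    out = torch.cat(outs)                  # torch.cat(torch.futures.wait_all(out_futures))
    assert out.shape == (8, 4)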
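3. The set_device_map calls added to run_worker pin each CUDA RPC hop to a GPU
   pair: the master maps its cuda:i to cuda:i on worker{i+1}, worker1 maps its
   cuda:0 back to the master's cuda:0, and every later worker maps its own GPU
   to the previous worker's GPU. The helper below is hypothetical; it only
   mirrors the patch's logic to make the mapping concrete, e.g. for --gpus 3
   (ranks: 0 = master, 1..3 = worker1..worker3):

    def cuda_device_maps(rank, gpus):
        # Returns {peer: {local_gpu: remote_gpu}} as configured in run_worker.
        if rank == 0:  # master
            return {f"worker{i + 1}": {i: i} for i in range(gpus)}
        if rank == 1:
            return {"master": {0: 0}}
        return {f"worker{rank - 1}": {rank - 1: rank - 2}}

    assert cuda_device_maps(0, 3) == {"worker1": {0: 0}, "worker2": {1: 1}, "worker3": {2: 2}}
    assert cuda_device_maps(3, 3) == {"worker2": {2: 1}}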