diff --git a/examples/BERT/cross_lingual_mlm_task.py b/examples/BERT/cross_lingual_mlm_task.py
index 916b88f93d..1a82db9d7d 100644
--- a/examples/BERT/cross_lingual_mlm_task.py
+++ b/examples/BERT/cross_lingual_mlm_task.py
@@ -11,16 +11,16 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.distributed.optim import DistributedOptimizer
+from torch.distributed.pipeline.sync import Pipe
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 
 from data import CC100
-from model import CrossLingualMLMTask
 from pipeline import SingleProcessPipeline, RPCPipeline, RemoteBaseCPURPC, RemoteBaseCUDARPC
 from shard_model import XLMRModelShards, MLMShards
+from torchtext.experimental.models.utils import count_model_param
 from torchtext.experimental.transforms import sentencepiece_tokenizer
 from transforms import PretrainedSPVocab
-from torchtext.experimental.models.utils import count_model_param
 
 
 def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -58,7 +58,7 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
     return total_loss / (len(data_source) - 1)  # Set batch # to 1 for inference
 
 
-def local_step(model, data, targets, criterion, optimizer, ntokens):
+def local_step(model, data, targets, criterion, optimizer, ntokens, args):
     optimizer.zero_grad()
     output = model(data)
     loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -69,7 +69,18 @@
     return res
 
 
-def dist_step(model, data, targets, criterion, optimizer, ntokens):
+def pipe_step(model, data, targets, criterion, optimizer, ntokens, args):
+    optimizer.zero_grad()
+    output = model(data).local_value()  # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef
+    loss = criterion(output.view(-1, ntokens), targets.view(-1))
+    loss.backward()
+    res = loss.item()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+    optimizer.step()
+    return res
+
+
+def rpc_step(model, data, targets, criterion, optimizer, ntokens, args):
     with dist_autograd.context() as context_id:
         output = model(data)
         loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -91,7 +102,7 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
     for batch, (data, targets) in enumerate(dataloader):
         data = data.to(devices[0])
         targets = targets.to(devices[-1])
-        loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
+        loss = step_impl(model, data, targets, criterion, optimizer, ntokens, args)
         total_loss += loss
 
         if batch % args.log_interval == 0 and batch > 0:
@@ -99,11 +110,11 @@
             elapsed = time.time() - start_time
             train_loss_log[-1] = cur_loss
             print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
-                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
-                                                          len(train_data) // args.batch_size,
-                                                          last_lr,
-                                                          elapsed * 1000 / args.log_interval,
-                                                          cur_loss, math.exp(cur_loss)))
+                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch,
+                                                      len(train_data) // args.batch_size,
+                                                      last_lr,
+                                                      elapsed * 1000 / args.log_interval,
+                                                      cur_loss, math.exp(cur_loss)))
             total_loss = 0
             start_time = time.time()
 
@@ -171,18 +182,27 @@ def text_transform(x: str) -> List:
     print("Allocating memory")
     if args.pipeline_mode == 'sp':
         model = SingleProcessPipeline(shards, devices)
         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0,
-                                                    gamma=0.75)
-    else:
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.75)
+    elif args.pipeline_mode == 'pipe':
+        model = Pipe(SingleProcessPipeline(shards, devices, to_device=False), chunks=args.batch_size // args.split_size)
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.75)
+    elif args.pipeline_mode == 'cpu' or args.pipeline_mode == 'cuda':
         workers = [f"worker{i+1}" for i in range(len(devices))]
-        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
+        if args.pipeline_mode == 'cpu':
+            impl = RemoteBaseCPURPC
+        elif args.pipeline_mode == 'cuda':
+            impl = RemoteBaseCUDARPC
+        model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=impl)
         optimizer = DistributedOptimizer(
             optim.Adam,
             model.parameter_rrefs(),
             lr=args.lr,
         )
         scheduler = None
+    else:
+        raise ValueError("Unsupported pipeline_mode")
 
     print("Memory allocated")
     # input("Memory allocated, check nvidia-smi for memory consumption")
@@ -199,16 +219,26 @@
 
         epoch_start_time = time.time()
         last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr
+
+        if args.pipeline_mode == 'sp':
+            step = local_step
+        elif args.pipeline_mode == 'pipe':
+            step = pipe_step
+        else:
+            step = rpc_step
+
+        if args.pipeline_mode == 'cpu':
+            train_devices = ["cpu"]  # Because "TensorPipe RPC backend only supports CPU tensors by default,
+                                     # please move your tensors to CPU before sending them over RPC"
+        else:
+            train_devices = devices
+
         train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
-              optimizer, criterion, ntokens, epoch, last_lr, args,
-              devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
-              local_step if args.pipeline_mode == 'sp' else dist_step)
+              optimizer, criterion, ntokens, epoch, last_lr, args, train_devices, step)
 
         # Turn on evaluation mode which disables dropout.
         model.eval()
-        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
-                            devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
-                            text_transform)
+        val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, train_devices, text_transform)
         val_loss_log.append(val_loss)
         print('-' * 89)
         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
@@ -253,7 +283,7 @@ def _forward(x):
     print('-' * 89)
 
 
-def run_worker(rank, args):
+def run_worker(rank, world_size, args):
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '29500'
     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
@@ -265,7 +295,7 @@ def run_worker(rank, args):
         rpc.init_rpc(
             "master",
             rank=rank,
-            world_size=args.gpus+1,
+            world_size=world_size,
             rpc_backend_options=options
         )
         run_main(args)
@@ -278,7 +308,7 @@ def run_worker(rank, args):
         rpc.init_rpc(
             f"worker{rank}",
             rank=rank,
-            world_size=args.gpus+1,
+            world_size=world_size,
             rpc_backend_options=options
         )
         pass
@@ -337,5 +367,8 @@ def run_worker(rank, args):
 
     if args.pipeline_mode == 'sp':
         run_main(args)
+    elif args.pipeline_mode == 'pipe':
+        # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef and requires RPC
+        mp.spawn(run_worker, args=(1, args), nprocs=1, join=True)
     else:
-        mp.spawn(run_worker, args=(args,), nprocs=args.gpus+1, join=True)
+        mp.spawn(run_worker, args=(args.gpus+1, args), nprocs=args.gpus+1, join=True)
diff --git a/examples/BERT/pipeline.py b/examples/BERT/pipeline.py
index bc6e8bd7a2..1baab50391 100644
--- a/examples/BERT/pipeline.py
+++ b/examples/BERT/pipeline.py
@@ -1,9 +1,11 @@
+import concurrent.futures
+import threading
+
 import torch
-import torch.nn as nn
 import torch.distributed.rpc as rpc
+import torch.nn as nn
 from torch.distributed.rpc import RRef
-import threading
-import concurrent.futures
+
 
 class ToDevice(nn.Module):
     def __init__(self, device):
@@ -15,19 +17,18 @@ def forward(self, x):
 
 
 class SingleProcessPipeline(nn.Sequential):
-    def __init__(self, shards, devices):
+    def __init__(self, shards, devices, to_device=True):
         super().__init__()
         assert len(shards) == len(devices)
         self.devices = devices
-        self.seq = nn.Sequential()
-
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            concurrent.futures.wait([executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])
+            concurrent.futures.wait(
+                [executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])
 
         for i, shard in enumerate(shards):
-            self.seq.add_module(f'Shard({devices[i]})', shard)
-            if i != len(shards)-1:
-                self.seq.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))
+            self.add_module(f'Shard({devices[i]})', shard)
+            if to_device and i != len(shards) - 1:
+                self.add_module(f'ToDevice({devices[i + 1]})', ToDevice(devices[i + 1]))
 
 
 class RemoteBaseCPURPC(nn.Module):
@@ -66,7 +67,8 @@ class RPCPipeline(nn.Module):
     def __init__(self, shards, devices, workers, remote_base_class=RemoteBaseCPURPC, split_size=1):
         super().__init__()
         self.split_size = split_size
-        self.shards = [rpc.remote(worker, remote_base_class, args=(shard, device)) for worker, shard, device in zip(workers, shards, devices)]
+        self.shards = [rpc.remote(worker, remote_base_class, args=(shard, device)) for worker, shard, device in
+                       zip(workers, shards, devices)]
 
     def forward(self, xs):
         # Split the input batch xs into micro-batches, and collect async RPC
@@ -87,4 +89,3 @@ def parameter_rrefs(self):
         for shard in self.shards:
             remote_params.extend(shard.remote().parameter_rrefs().to_here())
         return remote_params
-
\ No newline at end of file
diff --git a/examples/BERT/shard_model.py b/examples/BERT/shard_model.py
index b2878dc8f7..4643fce757 100644
--- a/examples/BERT/shard_model.py
+++ b/examples/BERT/shard_model.py
@@ -1,6 +1,8 @@
 import torch.nn as nn
+
 from model import TransformerEncoderLayer, XLMREmbedding
+
 
 class XLMRModelShards():
     def __init__(self, ntoken, ninp, nhead, nhid, dropout=0.5):
         self.ntoken = ntoken
@@ -15,7 +17,6 @@ def encoder_layers(self, nlayers):
         return nn.TransformerEncoder(self.encoder_layer, nlayers)
 
 
-
 class MLMShards():
     def __init__(self, ntoken, ninp):
         self.ntoken = ntoken
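
Note on the new 'pipe' mode: torch.distributed.pipeline.sync.Pipe is single-process pipeline parallelism, but it still needs a local RPC agent, which is why run_worker now takes world_size explicitly and 'pipe' mode spawns a single process with world_size=1, and why pipe_step unwraps the model output with .local_value(). A minimal, CPU-only sketch of that flow (the toy nn.Linear shards, sizes, and worker name are illustrative, not the example's XLM-R shards):

    import os

    import torch
    import torch.nn as nn
    import torch.distributed.rpc as rpc
    from torch.distributed.pipeline.sync import Pipe

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc("worker", rank=0, world_size=1)  # Pipe requires an RPC agent, even locally

    # Two toy shards standing in for the XLM-R embedding/encoder/MLM-head shards;
    # the real example places each shard on its own GPU via SingleProcessPipeline.
    model = Pipe(nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8)), chunks=4)

    x = torch.randn(8, 16)        # batch of 8 rows, pipelined as 4 micro-batches of 2
    out = model(x).local_value()  # Pipe.forward returns an RRef, hence .local_value()
    print(out.shape)              # torch.Size([8, 8])

    rpc.shutdown()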
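
Two design choices in the diff are easy to miss: chunks=args.batch_size // args.split_size makes each Pipe micro-batch exactly split_size rows, so 'pipe' mode pipelines at the same granularity as RPCPipeline, which splits batches by split_size itself; and the 'pipe' branch constructs SingleProcessPipeline(shards, devices, to_device=False) because Pipe performs cross-device transfers on its own, so the ToDevice modules that SingleProcessPipeline normally interleaves between shards would be redundant.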
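
For contrast with pipe_step, the RPC modes pipeline across processes: RPCPipeline.forward splits the input batch into micro-batches and collects async RPCs (only its introductory comment is visible in this diff). A hypothetical sketch of that pattern, modeled on the PyTorch RPC pipeline-parallelism tutorial; the shard-side forward(x_rref) signature is an assumption, not code confirmed by this diff:

    import torch
    from torch.distributed.rpc import RRef

    def forward(self, xs):
        # Assumed shape of RPCPipeline.forward: cut the batch into micro-batches
        # of self.split_size rows, push each through the remote shards, and gather.
        out_futures = []
        for x in xs.split(self.split_size, dim=0):
            x_rref = RRef(x)  # wrap the micro-batch so remote shards can fetch it
            for shard in self.shards[:-1]:
                x_rref = shard.remote().forward(x_rref)  # intermediate stays remote
            out_futures.append(self.shards[-1].rpc_async().forward(x_rref))
        return torch.cat(torch.futures.wait_all(out_futures))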