From a8dbbd13fca08a7894b02299d574b338c02286cb Mon Sep 17 00:00:00 2001
From: Gav Gray
Date: Sat, 10 Sep 2022 09:07:10 -0400
Subject: [PATCH 1/3] Changes required to use DataParallel

The cached MLP forward (`self.mlpf`) is a lambda that closes over
`self.mlp`, so the replicas that nn.DataParallel creates would still call
into the first device's copy of the MLP. Inline the MLP forward in
`Block.forward` instead, and average the loss in the trainer, since
DataParallel gathers the per-replica scalar losses into a vector.
---
 mingpt/model.py   | 8 +++++---
 mingpt/trainer.py | 7 +++++++
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/mingpt/model.py b/mingpt/model.py
index 83ee22dc..623df463 100644
--- a/mingpt/model.py
+++ b/mingpt/model.py
@@ -84,12 +84,14 @@ def __init__(self, config):
             act = NewGELU(),
             dropout = nn.Dropout(config.resid_pdrop),
         ))
-        m = self.mlp
-        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward
+        # m = self.mlp
+        # self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward
 
     def forward(self, x):
         x = x + self.attn(self.ln_1(x))
-        x = x + self.mlpf(self.ln_2(x))
+        m = self.mlp
+        # x = x + self.mlpf(self.ln_2(x))
+        x = x + m.dropout(m.c_proj(m.act(m.c_fc(self.ln_2(x))))) # MLP forward
         return x
 
 class GPT(nn.Module):
diff --git a/mingpt/trainer.py b/mingpt/trainer.py
index c0d08521..3cb8d70e 100644
--- a/mingpt/trainer.py
+++ b/mingpt/trainer.py
@@ -7,6 +7,7 @@
 from collections import defaultdict
 
 import torch
+import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 
 from mingpt.utils import CfgNode as CN
@@ -64,6 +65,10 @@ def run(self):
         # setup the optimizer
         self.optimizer = model.configure_optimizers(config)
 
+        if torch.cuda.device_count() > 1:
+            model = nn.DataParallel(model)
+        model.to(self.device)
+
         # setup the dataloader
         train_loader = DataLoader(
             self.train_dataset,
@@ -91,6 +96,8 @@
 
                 # forward the model
                 logits, self.loss = model(x, y)
+                if self.loss.nelement() > 1:
+                    self.loss = self.loss.mean() # DataParallel can return a vector of losses
 
                 # backprop and update the parameters
                 model.zero_grad(set_to_none=True)

From d29554614d5d93ed43d097e3481de76fd5417355 Mon Sep 17 00:00:00 2001
From: Gav Gray
Date: Sat, 10 Sep 2022 09:11:27 -0400
Subject: [PATCH 2/3] Make data_parallel optional

nn.DataParallel is not always a win, so gate the wrapping behind a new
`data_parallel` config flag (default True) rather than keying off the
device count alone.
---
 mingpt/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mingpt/trainer.py b/mingpt/trainer.py
index 3cb8d70e..b7bac3f6 100644
--- a/mingpt/trainer.py
+++ b/mingpt/trainer.py
@@ -27,6 +27,7 @@ def get_default_config():
         C.betas = (0.9, 0.95)
         C.weight_decay = 0.1 # only applied on matmul weights
         C.grad_norm_clip = 1.0
+        C.data_parallel = True
         return C
 
     def __init__(self, config, model, train_dataset):
@@ -65,7 +66,7 @@ def run(self):
         # setup the optimizer
         self.optimizer = model.configure_optimizers(config)
 
-        if torch.cuda.device_count() > 1:
+        if torch.cuda.device_count() > 1 and config.data_parallel:
             model = nn.DataParallel(model)
         model.to(self.device)
 

From a362aa626be3926fe198aabd0a2847da0407bb83 Mon Sep 17 00:00:00 2001
From: Gav Gray
Date: Sat, 10 Sep 2022 09:13:55 -0400
Subject: [PATCH 3/3] Clean up commented code
---
 mingpt/model.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mingpt/model.py b/mingpt/model.py
index 623df463..6ee01d04 100644
--- a/mingpt/model.py
+++ b/mingpt/model.py
@@ -84,13 +84,10 @@ def __init__(self, config):
             act = NewGELU(),
             dropout = nn.Dropout(config.resid_pdrop),
         ))
-        # m = self.mlp
-        # self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward
 
     def forward(self, x):
         x = x + self.attn(self.ln_1(x))
         m = self.mlp
-        # x = x + self.mlpf(self.ln_2(x))
         x = x + m.dropout(m.c_proj(m.act(m.c_fc(self.ln_2(x))))) # MLP forward
         return x
 
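
Why the lambda had to go: nn.DataParallel replicates a module onto each
GPU, but a plain attribute like `self.mlpf` is carried over by a shallow
copy, so every replica's lambda still closes over the `self.mlp` that
lives on the first device. A minimal sketch of the failure mode, assuming
a machine with two or more visible CUDA devices (the `Block` below is a
hypothetical stand-in, not the minGPT class):

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)
            m = self.fc
            self.fwd = lambda x: m(x)  # closure pins the cuda:0 copy of fc

        def forward(self, x):
            return self.fwd(x)

    model = nn.DataParallel(Block().to('cuda'))
    x = torch.randn(8, 4, device='cuda')
    try:
        model(x)  # replicas on cuda:1+ feed their input shard into cuda:0 weights
    except RuntimeError as err:
        print(err)  # expect a device-mismatch error once a second GPU is involved

Inlining the ops in `forward`, as PATCH 1 does, sidesteps this: `forward`
looks the submodules up through `self` at call time, and each replica's
`self.mlp` has been moved to the right device.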
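
Usage sketch for the new flag, assuming minGPT's standard entry points
(`GPT.get_default_config()` and `Trainer(config, model, train_dataset)`)
and a hypothetical `train_dataset` that yields (x, y) pairs of token
index tensors:

    from mingpt.model import GPT
    from mingpt.trainer import Trainer

    model_config = GPT.get_default_config()
    model_config.model_type = 'gpt-mini'
    model_config.vocab_size = 50257
    model_config.block_size = 128
    model = GPT(model_config)

    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4
    train_config.max_iters = 2000
    train_config.data_parallel = False  # opt out of multi-GPU wrapping
    trainer = Trainer(train_config, model, train_dataset)
    trainer.run()

With `data_parallel` left at its default of True, a multi-GPU box wraps
the model in nn.DataParallel; each replica returns a scalar loss, which
DataParallel gathers into a vector with one entry per device, and that is
why the trainer averages it with `self.loss.mean()` before backprop.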