Skip to content

Commit 4af2df3

Browse files
committed
distribute bert
1 parent 0a776f5 commit 4af2df3

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/mixins.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def _train_batch(self, x, y):
7070
inputs = self.tokenizer(x, return_tensors="pt", truncation=True, padding=True)
7171
if isinstance(inputs, dict) and inputs["input_ids"]:
7272
inputs["input_ids"] = inputs["input_ids"].float()
73-
if False and isinstance(y, torch.Tensor) and y.dtype == torch.float32:
73+
# todo : fix this case for bert vs lstm
74+
if True and isinstance(y, torch.Tensor) and y.dtype == torch.float32:
7475
labels = y.long()
7576
else:
7677
labels = y.float()
@@ -84,6 +85,7 @@ def _train_batch(self, x, y):
8485
_, preds = torch.max(outputs.logits, dim=1)
8586
except AttributeError:
8687
loss = self.criterion(outputs, labels)
88+
# todo : error on accuracy return always 0
8789
preds = outputs
8890
correct = (preds == labels).sum().item()
8991
acc = correct / len(labels)

src/ml.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,11 @@ def load_checkpoint(self):
426426
self.criterion = torch.nn.CrossEntropyLoss()
427427

428428
def save(self):
429-
self.model.save_pretrained(self.checkpoint)
429+
print(dist.is_available(), dist.is_initialized())
430+
if dist.is_available() and dist.is_initialized():
431+
self.model.module.save_pretrained(self.checkpoint)
432+
else:
433+
self.model.save_pretrained(self.checkpoint)
430434
self.tokenizer.save_pretrained(self.checkpoint)
431435

432436
def predict(self, x: list):

0 commit comments

Comments
 (0)