Skip to content

Commit 640d348

Browse files
committed
distribute
1 parent a701f25 commit 640d348

1 file changed

Lines changed: 17 additions & 7 deletions

File tree

src/ml.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,13 @@ def __init__(self, dataset: pd.DataFrame):
217217
self.dataloader, self.sampler = self.get_ddp_dataloader()
218218
logger.info(f"Rank {dist.get_rank()} using DDP")
219219

220+
def parralle_model(self):
221+
self.model = self.model.cuda(f"cuda:{self.local_rank}")
222+
self.model = nn.parallel.DistributedDataParallel(self.model,
223+
device_ids=[self.local_rank],
224+
output_device=self.local_rank,
225+
find_unused_parameters=True)
226+
220227
def preprocessing(self, data):
221228
return self.tokenizer(list(data), return_tensors="pt", truncation=True, padding=True)
222229

@@ -409,6 +416,9 @@ def load_checkpoint(self):
409416
ignore_mismatched_sizes=True,
410417
num_labels=self.out_features,
411418
)
419+
if dist.is_available() and dist.is_initialized():
420+
self.model = self.parralle_model()
421+
412422
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
413423
self.scheduler = torch.optim.lr_scheduler.StepLR(
414424
self.optimizer, step_size=8, gamma=0.248
@@ -445,7 +455,7 @@ class LSTMModel(TorchBaseModel):
445455
name = "LSTM"
446456
dataset_class = TweetDataset
447457
epoch = 1
448-
batch_size = 100
458+
batch_size = 32
449459
# test with BCEWithLogitLoss -> 1 logit -> post traitment sigmoïd
450460
out_features = 1
451461
lr = 1e-4
@@ -487,12 +497,9 @@ def load_checkpoint(self):
487497
self.model.load_state_dict(embedding_weights, strict=False)
488498
self.model.eval()
489499

490-
491-
self.model = self.model.cuda(f"cuda:{self.local_rank}")
492-
self.model = nn.parallel.DistributedDataParallel(self.model,
493-
device_ids=[self.local_rank],
494-
output_device=self.local_rank,
495-
find_unused_parameters=True)
500+
if dist.is_available() and dist.is_initialized():
501+
self.model = self.parralle_model()
502+
496503
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
497504
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
498505
self.optimizer, mode="min", factor=0.5, patience=2
@@ -513,6 +520,9 @@ def reinit_scheduler_optimizer(self, lr, factor, patience):
513520

514521
def save(self):
515522
# Create the parent directory saving tokenizer
523+
if dist.is_available() and dist.is_initialized():
524+
if dist.get_rank() != 0:
525+
return # ne rien faire sur les autres GPU
516526
self.tokenizer.save_pretrained(self.checkpoint)
517527
torch.save(self.model.state_dict(), self.checkpoint + "/model.pth")
518528

0 commit comments

Comments
 (0)