@@ -217,6 +217,13 @@ def __init__(self, dataset: pd.DataFrame):
217217 self .dataloader , self .sampler = self .get_ddp_dataloader ()
218218 logger .info (f"Rank { dist .get_rank ()} using DDP" )
219219
def parralle_model(self):
    """Move ``self.model`` to this rank's GPU and wrap it for distributed training.

    The wrapped module is stored on ``self.model`` and also returned. Callers
    assign the result back (``self.model = self.parralle_model()``), so the
    method must return the model; otherwise that assignment would replace the
    model with ``None`` and the following ``self.model.parameters()`` call
    would fail.

    Returns:
        The ``DistributedDataParallel``-wrapped model (the same object that is
        stored on ``self.model``).
    """
    # Place the module on the GPU indexed by this process's local rank
    # before handing it to DDP.
    self.model = self.model.cuda(f"cuda:{self.local_rank}")
    self.model = nn.parallel.DistributedDataParallel(
        self.model,
        device_ids=[self.local_rank],
        output_device=self.local_rank,
        # Per the PyTorch DDP docs, this lets the reducer cope with
        # parameters that take no part in a given forward pass.
        find_unused_parameters=True,
    )
    return self.model
226+
def preprocessing(self, data):
    """Tokenize a batch of inputs with ``self.tokenizer``.

    Args:
        data: Iterable of inputs; it is materialised into a list before
            being passed to the tokenizer.

    Returns:
        Whatever ``self.tokenizer`` returns when called with
        ``return_tensors="pt"``, ``truncation=True`` and ``padding=True``.
    """
    batch = list(data)
    tokenizer_options = {
        "return_tensors": "pt",
        "truncation": True,
        "padding": True,
    }
    return self.tokenizer(batch, **tokenizer_options)
222229
@@ -409,6 +416,9 @@ def load_checkpoint(self):
409416 ignore_mismatched_sizes = True ,
410417 num_labels = self .out_features ,
411418 )
419+ if dist .is_available () and dist .is_initialized ():
420+ self .model = self .parralle_model ()
421+
412422 self .optimizer = torch .optim .Adam (self .model .parameters (), lr = self .lr )
413423 self .scheduler = torch .optim .lr_scheduler .StepLR (
414424 self .optimizer , step_size = 8 , gamma = 0.248
@@ -445,7 +455,7 @@ class LSTMModel(TorchBaseModel):
445455 name = "LSTM"
446456 dataset_class = TweetDataset
447457 epoch = 1
448- batch_size = 100
458+ batch_size = 32
449459 # test with BCEWithLogitsLoss -> 1 logit -> sigmoid applied in post-processing
450460 out_features = 1
451461 lr = 1e-4
@@ -487,12 +497,9 @@ def load_checkpoint(self):
487497 self .model .load_state_dict (embedding_weights , strict = False )
488498 self .model .eval ()
489499
490-
491- self .model = self .model .cuda (f"cuda:{ self .local_rank } " )
492- self .model = nn .parallel .DistributedDataParallel (self .model ,
493- device_ids = [self .local_rank ],
494- output_device = self .local_rank ,
495- find_unused_parameters = True )
500+ if dist .is_available () and dist .is_initialized ():
501+ self .model = self .parralle_model ()
502+
496503 self .optimizer = torch .optim .Adam (self .model .parameters (), lr = self .lr )
497504 self .scheduler = torch .optim .lr_scheduler .ReduceLROnPlateau (
498505 self .optimizer , mode = "min" , factor = 0.5 , patience = 2
@@ -513,6 +520,9 @@ def reinit_scheduler_optimizer(self, lr, factor, patience):
513520
def save(self):
    """Write the tokenizer and the model weights under ``self.checkpoint``.

    When ``torch.distributed`` is available and initialised, only rank 0
    writes; every other rank returns immediately without touching disk.
    """
    distributed_run = dist.is_available() and dist.is_initialized()
    if distributed_run and dist.get_rank() != 0:
        return  # do nothing on the other GPUs

    # The tokenizer is saved first, into the checkpoint directory itself.
    self.tokenizer.save_pretrained(self.checkpoint)
    weights = self.model.state_dict()
    weights_path = self.checkpoint + "/model.pth"
    torch.save(weights, weights_path)
518528
0 commit comments