@@ -208,6 +208,10 @@ class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
208208 """
209209
210210 def __init__ (self , dataset : pd .DataFrame ):
211+ dist .init_process_group ("nccl" )
212+ if dist .is_initialized ():
213+ self .local_rank = dist .get_rank ()
214+ torch .cuda .set_device (self .local_rank )
211215 super ().__init__ (dataset )
212216 if dist .is_initialized ():
213217 self .dataloader , self .sampler = self .get_ddp_dataloader ()
@@ -474,7 +478,7 @@ def load_checkpoint(self):
474478 map_location = {"cuda:0" : "mps" , "cuda" : "mps" },
475479 )
476480 else :
477- checkpoint = torch .load (self .checkpoint + "/model.pth" ,map_location = f"cuda:{ local_rank } " )
481+ checkpoint = torch .load (self .checkpoint + "/model.pth" ,map_location = f"cuda:{ self . local_rank } " )
478482 embedding_weights = {
479483 k : v
480484 for k , v in checkpoint .items ()
@@ -483,12 +487,12 @@ def load_checkpoint(self):
483487 self .model .load_state_dict (embedding_weights , strict = False )
484488 self .model .eval ()
485489
486- dist .init_process_group ("nccl" )
487- local_rank = torch .distributed .get_rank ()
488- torch .cuda .set_device (local_rank )
489490
490- self .model = self .model .cuda (local_rank )
491- self .model = nn .parallel .DistributedDataParallel (self .model , device_ids = [local_rank ], output_device = local_rank ,find_unused_parameters = True )
491+ self .model = self .model .cuda (self .local_rank )
492+ self .model = nn .parallel .DistributedDataParallel (self .model ,
493+ device_ids = [self .local_rank ],
494+ output_device = self .local_rank ,
495+ find_unused_parameters = True )
492496 self .optimizer = torch .optim .Adam (self .model .parameters (), lr = self .lr )
493497 self .scheduler = torch .optim .lr_scheduler .ReduceLROnPlateau (
494498 self .optimizer , mode = "min" , factor = 0.5 , patience = 2
0 commit comments