Skip to content

Commit c69eb47

Browse files
committed
distribute
1 parent 5a023dc commit c69eb47

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

src/ml.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,10 @@ class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
208208
"""
209209

210210
def __init__(self, dataset: pd.DataFrame):
211+
dist.init_process_group("nccl")
212+
if dist.is_initialized():
213+
self.local_rank = dist.get_rank()
214+
torch.cuda.set_device(self.local_rank)
211215
super().__init__(dataset)
212216
if dist.is_initialized():
213217
self.dataloader, self.sampler = self.get_ddp_dataloader()
@@ -474,7 +478,7 @@ def load_checkpoint(self):
474478
map_location={"cuda:0": "mps", "cuda": "mps"},
475479
)
476480
else:
477-
checkpoint = torch.load(self.checkpoint + "/model.pth",map_location=f"cuda:{local_rank}")
481+
checkpoint = torch.load(self.checkpoint + "/model.pth",map_location=f"cuda:{self.local_rank}")
478482
embedding_weights = {
479483
k: v
480484
for k, v in checkpoint.items()
@@ -483,12 +487,12 @@ def load_checkpoint(self):
483487
self.model.load_state_dict(embedding_weights, strict=False)
484488
self.model.eval()
485489

486-
dist.init_process_group("nccl")
487-
local_rank = torch.distributed.get_rank()
488-
torch.cuda.set_device(local_rank)
489490

490-
self.model = self.model.cuda(local_rank)
491-
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank,find_unused_parameters=True)
491+
self.model = self.model.cuda(self.local_rank)
492+
self.model = nn.parallel.DistributedDataParallel(self.model,
493+
device_ids=[self.local_rank],
494+
output_device=self.local_rank,
495+
find_unused_parameters=True)
492496
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
493497
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
494498
self.optimizer, mode="min", factor=0.5, patience=2

0 commit comments

Comments
 (0)