Skip to content

Commit 30a50e4

Browse files
committed
distribute GPU
1 parent 66ece1f commit 30a50e4

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

src/mixins.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import optuna
99
import random
1010
from torch.utils.data import Subset, DataLoader
11+
import torch.distributed as dist
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -76,7 +77,7 @@ def _train_batch(self, x, y):
7677
acc = correct / len(labels)
7778
loss.backward()
7879
self.optimizer.step()
79-
logger.info(f"loss {loss.item()}")
80+
logger.info(f" Rank {dist.get_rank()} loss {loss.item()}")
8081
mlflow.log_metric("loss", loss.item())
8182
mlflow.log_metric("acc", acc)
8283
mlflow.log_metric("time", time.time())

src/ml.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import matplotlib.pyplot as plt
1212
from multiprocessing import cpu_count
1313
import torch
14+
import torch.nn as nn
15+
import torch.distributed as dist
1416
from tqdm import tqdm
15-
from torch.utils.data import DataLoader
17+
from torch.utils.data import DataLoader, DistributedSampler
1618
import torch.nn.functional as F
1719
from sklearn.linear_model import LogisticRegression
1820
import lightgbm as lgm
@@ -473,16 +475,14 @@ def load_checkpoint(self):
473475
}
474476
self.model.load_state_dict(embedding_weights, strict=False)
475477
self.model.eval()
476-
import torch
477-
import torch.nn as nn
478-
import torch.distributed as dist
479478

480479
dist.init_process_group("nccl")
481480
local_rank = torch.distributed.get_rank()
482481
torch.cuda.set_device(local_rank)
483482

484483
self.model = self.model.cuda(local_rank)
485484
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank,find_unused_parameters=True)
485+
self.dataset = DistributedSampler(self.dataset)
486486
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
487487
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
488488
self.optimizer, mode="min", factor=0.5, patience=2

0 commit comments

Comments
 (0)