Skip to content

Commit 5d1f3a7

Browse files
committed
distribute train
1 parent 4ad17ec commit 5d1f3a7

3 files changed

Lines changed: 31 additions & 8 deletions

File tree

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,12 @@ Credentials are stored in ~/.config/ovhai/context.json
2222
```bash
2323
uv pip install boto3 awscli ovhai
2424
```
25+
26+
## Run on multi GPU
27+
DEBUG
28+
```bash
29+
export TORCH_DISTRIBUTED_DEBUG=DETAIL
30+
```
31+
```bash
32+
python -m torch.distributed.run --nproc_per_node=2 train.py
33+
```

src/mixins.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import logging
88
import optuna
99
import random
10-
from torch.utils.data import Subset, DataLoader
10+
from torch.utils.data import Subset, DataLoader, DistributedSampler
1111
import torch.distributed as dist
12+
import pandas as pd
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -22,11 +23,15 @@ class TorchModelTrainMixin:
2223
lr: float = 2e-5
2324
device: torch.device
2425

25-
def get_sampled_dataloader(self, frac=0.1):
26+
def sample_dataset(self, frac=0.1):
2627
dataset_size = len(self.dataset)
2728
sample_size = int(frac * dataset_size)
2829
indices = random.sample(range(dataset_size), sample_size)
2930
sampled_dataset = Subset(self.dataset, indices)
31+
return sampled_dataset
32+
33+
def get_sampled_dataloader(self, frac=0.1):
34+
sampled_dataset = self.sample_dataset(frac=frac)
3035
return DataLoader(sampled_dataset, batch_size=self.batch_size, shuffle=True)
3136

3237
def optuna_train(self, run_name:str = "", n_trials:int=30, frac=0.1):
@@ -53,6 +58,13 @@ def objective(self, trial):
5358
self.reinit_scheduler_optimizer(**kwargs)
5459
acc = self.train()
5560
return acc
61+
62+
def get_ddp_dataloader(self, frac=1.0):
63+
sampled_dataset = self.sample_dataset(frac=frac)
64+
sampler = DistributedSampler(sampled_dataset)
65+
dataloader = DataLoader(sampled_dataset, batch_size=self.batch_size, sampler=sampler)
66+
return dataloader, sampler
67+
5668

5769
def _train_batch(self, x, y):
5870
inputs = self.tokenizer(x, return_tensors="pt", truncation=True, padding=True)
@@ -103,7 +115,6 @@ def train(self):
103115
current_acc = self._train_batch(tweets, labels.float())
104116
acc.append(current_acc)
105117
except RuntimeError as e:
106-
raise e
107118
logger.error(e)
108119
del tweets, labels, self.optimizer
109120
gc.collect()

src/ml.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch.nn as nn
1515
import torch.distributed as dist
1616
from tqdm import tqdm
17-
from torch.utils.data import DataLoader, DistributedSampler
17+
from torch.utils.data import DataLoader
1818
import torch.nn.functional as F
1919
from sklearn.linear_model import LogisticRegression
2020
import lightgbm as lgm
@@ -206,6 +206,13 @@ class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
206206
"""
207207
Base class to train and predict on a dataset and register data on MLFLow
208208
"""
209+
210+
def __init__(self, dataset: pd.DataFrame):
211+
super().__init__(dataset)
212+
if dist.is_initialized():
213+
self.dataloader, self.sampler = self.get_ddp_dataloader()
214+
logger.info(f"Rank {dist.get_rank()} using DDP")
215+
209216
def preprocessing(self, data):
210217
return self.tokenizer(list(data), return_tensors="pt", truncation=True, padding=True)
211218

@@ -441,10 +448,6 @@ class LSTMModel(TorchBaseModel):
441448
device = DEVICE
442449
# torch.nn.CrossEntropyLoss()
443450

444-
def __init__(self, dataset: pd.DataFrame):
445-
super().__init__(dataset)
446-
self.dataset = DistributedSampler(self.dataset)
447-
448451
@property
449452
def get_metrics(self) -> dict:
450453
for k, v in self.model.state_dict().items():

0 commit comments

Comments
 (0)