Skip to content

Commit a56fb40

Browse files
committed
distribute
1 parent d6aed1c commit a56fb40

3 files changed

Lines changed: 18 additions & 5 deletions

File tree

mlruns/0/meta.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
artifact_location: file:///Users/wonters/Desktop/openclassroom/projets/sentiment_analyses/sentimental_analyses/mlruns/0
2-
creation_time: 1744984721760
1+
artifact_location: file:///home/debian/sentimental_analyses/mlruns/0
2+
creation_time: 1745341000952
33
experiment_id: '0'
4-
last_update_time: 1744984721760
4+
last_update_time: 1745341000952
55
lifecycle_stage: active
66
name: Default

src/ml.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ class LSTMModel(TorchBaseModel):
432432
name = "LSTM"
433433
dataset_class = TweetDataset
434434
epoch = 1
435-
batch_size = 120
435+
batch_size = 100
436436
# test with BCEWithLogitLoss -> 1 logit -> post traitment sigmoïd
437437
out_features = 1
438438
lr = 1e-4
@@ -473,6 +473,16 @@ def load_checkpoint(self):
473473
}
474474
self.model.load_state_dict(embedding_weights, strict=False)
475475
self.model.eval()
476+
import torch
477+
import torch.nn as nn
478+
import torch.distributed as dist
479+
480+
dist.init_process_group("nccl")
481+
local_rank = torch.distributed.get_rank()
482+
torch.cuda.set_device(local_rank)
483+
484+
self.model = self.model.cuda(local_rank)
485+
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank,find_unused_parameters=True)
476486
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
477487
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
478488
self.optimizer, mode="min", factor=0.5, patience=2

train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
## Import lightgbm avoiding segfault error, protection against segfault
22
#import lightgbm as lgb
3+
import logging
34
from src.ml import LightGBMModel, load_data, LSTMModel
45

6+
logging.basicConfig(level=logging.INFO)
7+
58
df = load_data('../data/training.1600000.processed.noemoticon.csv')
6-
#df = df.sample(frac=1, random_state=42)
9+
df = df.sample(frac=0.1, random_state=42)
710
model = LSTMModel(df)
811
model.train()

0 commit comments

Comments
 (0)