Skip to content

Commit 93b6a50

Browse files
committed
LSTM optim optuna
1 parent 0c83668 commit 93b6a50

3 files changed

Lines changed: 52 additions & 25 deletions

File tree

src/mixins.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,26 @@ def optuna_train(self, run_name:str = "", n_trials:int=30, frac=0.1):
3838
study = optuna.create_study(direction="maximize",
3939
pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=1))
4040
study.optimize(self.objective, n_trials=n_trials)
41-
42-
43-
def objective(self, trial):
41+
42+
def params_optim(self, trial):
4443
lr = trial.suggest_loguniform('lr', 1e-6, 1e-3)
4544
gamma = trial.suggest_float('gamma', 0.1, 0.9)
4645
step_size = trial.suggest_int('step_size', 2, 10)
46+
return {'lr': lr, 'gamma': gamma, 'step_size': step_size}
47+
48+
def objective(self, trial):
49+
kwargs = self.params_optim(trial)
4750
with mlflow.start_run(nested=True):
48-
mlflow.log_params({
49-
"lr": lr,
50-
"gamma": gamma,
51-
"step_size": step_size
52-
})
53-
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
54-
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)
51+
mlflow.log_params(kwargs)
52+
self.reinit_scheduler_optimizer(**kwargs)
5553
acc = self.train()
5654
return acc
5755

5856
def _train_batch(self, x, y):
5957
inputs = self.tokenizer(x, return_tensors="pt", truncation=True, padding=True)
6058
if isinstance(inputs, dict) and inputs["input_ids"]:
6159
inputs["input_ids"] = inputs["input_ids"].float()
62-
if isinstance(y, torch.Tensor) and y.dtype == torch.float32:
60+
if False and isinstance(y, torch.Tensor) and y.dtype == torch.float32:
6361
labels = y.long()
6462
else:
6563
labels = y.float()
@@ -69,9 +67,10 @@ def _train_batch(self, x, y):
6967
outputs = self.model(**inputs)
7068
try:
7169
loss = self.criterion(outputs.logits, labels)
70+
_, preds = torch.max(outputs.logits, dim=1)
7271
except AttributeError:
7372
loss = self.criterion(outputs, labels)
74-
_, preds = torch.max(outputs.logits, dim=1)
73+
preds = outputs
7574
correct = (preds == labels).sum().item()
7675
acc = correct / len(labels)
7776
loss.backward()
@@ -99,8 +98,10 @@ def train(self):
9998
for epoch in tqdm(range(self.epoch)):
10099
for tweets, labels in tqdm(self.dataloader):
101100
try:
102-
acc.append(self._train_batch(tweets, labels.float()))
101+
current_acc = self._train_batch(tweets, labels.float())
102+
acc.append(current_acc)
103103
except RuntimeError as e:
104+
raise e
104105
logger.error(e)
105106
del tweets, labels, self.optimizer
106107
gc.collect()
@@ -120,7 +121,10 @@ def train(self):
120121
logger.info(
121122
f"CUDA allocated memory: {torch.cuda.memory_allocated()}"
122123
)
123-
self.scheduler.step()
124+
if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
125+
self.scheduler.step(current_acc)
126+
else:
127+
self.scheduler.step()
124128
super().train()
125129
return sum(acc)/len(acc)
126130

src/ml.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,16 @@ class BertModel(TorchBaseModel):
357357
lr = 2.561e-4
358358
device = DEVICE
359359

360+
def params_optim(self, trial):
361+
lr = trial.suggest_loguniform('lr', 1e-6, 1e-3)
362+
gamma = trial.suggest_float('gamma', 0.1, 0.9)
363+
step_size = trial.suggest_int('step_size', 2, 10)
364+
return {'lr': lr, 'gamma': gamma, 'step_size': step_size}
365+
366+
def reinit_scheduler_optimizer(self, lr, gamma, step_size):
367+
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
368+
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)
369+
360370
def load_checkpoint(self):
361371
if Path(self.checkpoint).exists():
362372
self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
@@ -449,27 +459,40 @@ def load_checkpoint(self):
449459
self.model.load_state_dict(embedding_weights, strict=False)
450460
self.model.eval()
451461
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
452-
# self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.1)
453462
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
454463
self.optimizer, mode="min", factor=0.5, patience=2
455464
)
456465
self.criterion = torch.nn.BCEWithLogitsLoss()
457466

467+
def params_optim(self, trial):
468+
lr = trial.suggest_loguniform('lr', 1e-6, 1e-3)
469+
factor = trial.suggest_float('factor', 0.1, 0.9)
470+
patience = trial.suggest_int('patience', 2, 10)
471+
return {'lr': lr, 'factor': factor, 'patience': patience}
472+
473+
def reinit_scheduler_optimizer(self, lr, factor, patience):
474+
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
475+
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
476+
self.optimizer, mode="min", factor=factor, patience=patience
477+
)
478+
458479
def save(self):
459480
# Create the parent directory saving tokenizer
460481
self.tokenizer.save_pretrained(self.checkpoint)
461482
torch.save(self.model.state_dict(), self.checkpoint + "/model.pth")
462483

463484
def predict(self, x):
464-
inputs = self.preprocessing(x)
465-
inputs = inputs.to(self.device)
466-
with torch.no_grad():
467-
outputs = self.model(**inputs)
468-
# Appliquer sigmoïde sur les 4 prédictions
469-
probs = torch.sigmoid(outputs)
470-
# Convertir en classes (0 ou 1) en utilisant un seuil de 0.5
471-
predicted_classes = (probs > 0.5).int()
472-
return predicted_classes.tolist()
485+
predicted_class = []
486+
for i in range(0, len(x), self.batch_size):
487+
inputs = self.preprocessing(x[i:i+self.batch_size])
488+
inputs = inputs.to(self.device)
489+
with torch.no_grad():
490+
outputs = self.model(**inputs)
491+
# Appliquer sigmoïde sur les 4 prédictions
492+
probs = torch.sigmoid(outputs)
493+
# Convertir en classes (0 ou 1) en utilisant un seuil de 0.5
494+
predicted_class.extend((probs > 0.5).int().tolist())
495+
return predicted_class
473496

474497

475498
def split_data(df: pd.DataFrame, shuffle: bool = True):

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
file = "../data/training.1600000.processed.noemoticon.csv"
99
original_df = load_data(file)
1010
model = LSTMModel(original_df)
11-
model.optuna_train(n_trials=5, frac=0.01)
11+
model.optuna_train(n_trials=5, frac=0.001)

0 commit comments

Comments
 (0)