@@ -357,6 +357,16 @@ class BertModel(TorchBaseModel):
357357 lr = 2.561e-4
358358 device = DEVICE
359359
def params_optim(self, trial):
    """Sample optimizer and StepLR hyperparameters from a tuning trial.

    Args:
        trial: Object exposing ``suggest_float`` and ``suggest_int``
            (presumably an ``optuna.trial.Trial`` -- confirm against caller).

    Returns:
        dict: The sampled values under the keys ``'lr'``, ``'gamma'`` and
        ``'step_size'``, the same names ``reinit_scheduler_optimizer``
        takes as parameters.
    """
    return {
        # ``suggest_float(..., log=True)`` replaces the deprecated
        # ``suggest_loguniform``; the log-uniform range is unchanged.
        'lr': trial.suggest_float('lr', 1e-6, 1e-3, log=True),
        'gamma': trial.suggest_float('gamma', 0.1, 0.9),
        'step_size': trial.suggest_int('step_size', 2, 10),
    }
365+
def reinit_scheduler_optimizer(self, lr, gamma, step_size):
    """Replace the optimizer and scheduler with newly configured ones.

    Args:
        lr: Learning rate passed to the new Adam optimizer.
        gamma: ``gamma`` argument forwarded to ``StepLR``.
        step_size: ``step_size`` argument forwarded to ``StepLR``.
    """
    # Build a new Adam optimizer over the model's current parameters.
    adam = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.optimizer = adam
    # The scheduler is bound to the optimizer that was just created.
    step_lr = torch.optim.lr_scheduler.StepLR(
        adam,
        gamma=gamma,
        step_size=step_size,
    )
    self.scheduler = step_lr
369+
360370 def load_checkpoint (self ):
361371 if Path (self .checkpoint ).exists ():
362372 self .tokenizer = AutoTokenizer .from_pretrained (self .checkpoint )
@@ -449,27 +459,40 @@ def load_checkpoint(self):
449459 self .model .load_state_dict (embedding_weights , strict = False )
450460 self .model .eval ()
451461 self .optimizer = torch .optim .Adam (self .model .parameters (), lr = self .lr )
452- # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.1)
453462 self .scheduler = torch .optim .lr_scheduler .ReduceLROnPlateau (
454463 self .optimizer , mode = "min" , factor = 0.5 , patience = 2
455464 )
456465 self .criterion = torch .nn .BCEWithLogitsLoss ()
457466
def params_optim(self, trial):
    """Sample optimizer and ReduceLROnPlateau hyperparameters from a trial.

    Args:
        trial: Object exposing ``suggest_float`` and ``suggest_int``
            (presumably an ``optuna.trial.Trial`` -- confirm against caller).

    Returns:
        dict: The sampled values under the keys ``'lr'``, ``'factor'`` and
        ``'patience'``, the same names ``reinit_scheduler_optimizer``
        takes as parameters.
    """
    return {
        # ``suggest_float(..., log=True)`` replaces the deprecated
        # ``suggest_loguniform``; the log-uniform range is unchanged.
        'lr': trial.suggest_float('lr', 1e-6, 1e-3, log=True),
        'factor': trial.suggest_float('factor', 0.1, 0.9),
        'patience': trial.suggest_int('patience', 2, 10),
    }
472+
def reinit_scheduler_optimizer(self, lr, factor, patience):
    """Replace the optimizer and scheduler with newly configured ones.

    Args:
        lr: Learning rate passed to the new Adam optimizer.
        factor: ``factor`` argument forwarded to ``ReduceLROnPlateau``.
        patience: ``patience`` argument forwarded to ``ReduceLROnPlateau``.
    """
    # Build a new Adam optimizer over the model's current parameters.
    trainable = self.model.parameters()
    self.optimizer = torch.optim.Adam(trainable, lr=lr)
    # The scheduler always runs in "min" mode; only factor/patience vary.
    plateau_kwargs = {"mode": "min", "factor": factor, "patience": patience}
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        self.optimizer, **plateau_kwargs
    )
478+
def save(self):
    """Write the tokenizer and the model weights under ``self.checkpoint``.

    The tokenizer is saved with ``save_pretrained`` and the model's
    ``state_dict`` is written to ``model.pth`` inside the same location.
    """
    # Saving the tokenizer first creates the checkpoint directory
    # (per the original comment), so the weights file can be written next.
    self.tokenizer.save_pretrained(self.checkpoint)
    # Joining through Path works for both ``str`` and ``Path`` checkpoints,
    # whereas ``checkpoint + "/model.pth"`` raises TypeError for a Path.
    weights_path = Path(self.checkpoint) / "model.pth"
    torch.save(self.model.state_dict(), weights_path)
462483
def predict(self, x):
    """Predict thresholded labels for ``x`` in chunks of ``self.batch_size``.

    Args:
        x: Sized, sliceable collection of inputs accepted by
            ``self.preprocessing``.

    Returns:
        list: Concatenation, over all batches, of ``Tensor.tolist()`` applied
        to the thresholded outputs. Presumably one list of 0/1 ints per
        input (the original comment mentions 4 predictions) -- the exact
        nesting depends on the model's output shape.
    """
    predicted_class = []
    # Gradients are never needed for prediction, so disable them once
    # around the whole loop instead of re-entering the context per batch.
    with torch.no_grad():
        for start in range(0, len(x), self.batch_size):
            batch = x[start:start + self.batch_size]
            inputs = self.preprocessing(batch).to(self.device)
            outputs = self.model(**inputs)
            # Apply a sigmoid to each of the model's outputs.
            probs = torch.sigmoid(outputs)
            # Convert to classes (0 or 1) using a 0.5 threshold.
            predicted_class.extend((probs > 0.5).int().tolist())
    return predicted_class
473496
474497
475498def split_data (df : pd .DataFrame , shuffle : bool = True ):
0 commit comments