@@ -54,7 +54,7 @@ def compute_class_weights_from_series(s: pd.Series) -> torch.Tensor:
5454
5555
5656class FocalLoss (nn .Module ):
57- def __init__ (self , gamma = 3.0 , weight = None , reduction = "mean" ):
57+ def __init__ (self , gamma : float = 3.0 , weight : torch . Tensor | None = None , reduction : str | None = "mean" ):
5858 super ().__init__ ()
5959 self .gamma = gamma
6060 self .ce = nn .CrossEntropyLoss (weight = weight , reduction = "none" )
@@ -73,17 +73,25 @@ def forward(self, logits, targets):
7373
7474# --- Custom Trainer for weighted loss (pretraining stage) ---
7575class WeightedTrainer (Trainer ):
76- def __init__ (self , class_weights : torch .Tensor | None = None , use_focal_loss : bool = False , * args , ** kwargs ):
76+ def __init__ (
77+ self ,
78+ class_weights : torch .Tensor | None = None ,
79+ use_focal_loss : bool = False ,
80+ gamma : float = 3.0 ,
81+ * args ,
82+ ** kwargs ,
83+ ):
7784 super ().__init__ (* args , ** kwargs )
7885 self .class_weights = class_weights .to (self .args .device ) if class_weights is not None else None
7986 self .use_focal_loss = use_focal_loss
87+ self .gamma = gamma
8088
8189 def compute_loss (self , model , inputs , return_outputs = False , num_items_in_batch = None ):
8290 labels = inputs .get ("labels" )
8391 outputs = model (** {k : v for k , v in inputs .items () if k != "labels" })
8492 logits = outputs .logits
8593 if self .use_focal_loss :
86- loss_fct = FocalLoss (weight = self .class_weights )
94+ loss_fct = FocalLoss (gamma = self . gamma , weight = self .class_weights )
8795 else :
8896 loss_fct = nn .CrossEntropyLoss (weight = self .class_weights )
8997 loss = loss_fct (logits .view (- 1 , logits .size (- 1 )), labels .view (- 1 ))
@@ -94,7 +102,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
94102class DualEncoderForSequenceClassification (PreTrainedModel ):
95103 config_class = AutoConfig
96104
97- def __init__ (self , config , use_focal_loss : bool = False ):
105+ def __init__ (
106+ self ,
107+ config ,
108+ use_focal_loss : bool = False ,
109+ gamma : float = 3.0 ,
110+ ):
98111 super ().__init__ (config )
99112 self .num_labels = config .num_labels
100113 # instantiate two encoders from the pretrained config
@@ -108,6 +121,7 @@ def __init__(self, config, use_focal_loss: bool = False):
108121 self .dropout = nn .Dropout (getattr (config , "hidden_dropout_prob" , 0.1 ))
109122 self .classifier = nn .Linear (hidden_size , config .num_labels )
110123 self .use_focal_loss = use_focal_loss
124+ self .gamma = gamma
111125 self .post_init ()
112126
113127 def forward (self , input_ids = None , attention_mask = None , labels = None , return_dict = True ):
@@ -130,7 +144,7 @@ def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=
130144 if hasattr (self .config , "class_weights" ) and self .config .class_weights is not None :
131145 cw = torch .tensor (self .config .class_weights , device = logits .device , dtype = torch .float )
132146 if self .use_focal_loss :
133- loss_fct = FocalLoss (weight = cw )
147+ loss_fct = FocalLoss (gamma = self . gamma , weight = cw )
134148 else :
135149 loss_fct = nn .CrossEntropyLoss (weight = cw )
136150 loss = loss_fct (logits .view (- 1 , self .num_labels ), labels .view (- 1 ))
@@ -289,7 +303,8 @@ def train_pretrain_stage(args, logger):
289303 eval_dataset = val_ds ,
290304 compute_metrics = compute_metrics ,
291305 class_weights = class_weights ,
292- use_focal_loss = True ,
306+ use_focal_loss = args .use_focal_loss ,
307+ gamma = args .gamma ,
293308 callbacks = [EarlyStoppingCallback (early_stopping_patience = 3 )],
294309 )
295310 logger .info ("Starting pre-training..." )
@@ -315,7 +330,11 @@ def train_main_stage(args, logger, pretrain_trainer, tokenizer, full_df, freeze_
315330 # Build Dual Encoder
316331 config = AutoConfig .from_pretrained (args .model , num_labels = 2 )
317332 config .class_weights = class_weights .tolist ()
318- combined = DualEncoderForSequenceClassification (config , use_focal_loss = True )
333+ combined = DualEncoderForSequenceClassification (
334+ config ,
335+ use_focal_loss = args .use_focal_loss ,
336+ gamma = args .gamma ,
337+ )
319338
320339 # Load base encoder weights for text encoder (fresh from pretrained)
321340 base_model = AutoModel .from_pretrained (args .model )
@@ -483,6 +502,24 @@ def main():
483502 action = "store_true" ,
484503 help = "Skip pretraining stage and train dual encoder from scratch" ,
485504 )
505+ parser .add_argument (
506+ "--freeze-bio-encoder" ,
507+ dest = "freeze_bio_encoder" ,
508+ action = "store_true" ,
509+ help = "Freeze the bio encoder during main task training" ,
510+ )
511+ parser .add_argument (
512+ "--use-focal-loss" ,
513+ dest = "use_focal_loss" ,
514+ action = "store_true" ,
515+ help = "Use Focal Loss instead of Cross-Entropy Loss" ,
516+ )
517+ parser .add_argument (
518+ "--gamma" ,
519+ type = float ,
520+ default = 2.0 ,
521+ help = "Gamma parameter for Focal Loss" ,
522+ )
486523 args = parser .parse_args ()
487524
488525 # Setup logging and directories
@@ -550,7 +587,9 @@ def main():
550587 full_df = preprocess_df_texts (full_df , spanish = (args .lang in ["es" , "both" ]))
551588
552589 # Main Stage
553- main_trainer , test_dataset = train_main_stage (args , logger , pretrain_trainer , tokenizer , full_df )
590+ main_trainer , test_dataset = train_main_stage (
591+ args , logger , pretrain_trainer , tokenizer , full_df , freeze_bio_encoder = args .freeze_bio_encoder
592+ )
554593
555594 # Evaluate on test set
556595 logger .info ("Evaluating main model on test set..." )
0 commit comments