Skip to content

Commit afb5f04

Browse files
committed
feat: enhance Focal Loss implementation and add gamma parameter support
- added freeze bio encoder flag
1 parent 86d2ff1 commit afb5f04

1 file changed

Lines changed: 47 additions & 8 deletions

File tree

main.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def compute_class_weights_from_series(s: pd.Series) -> torch.Tensor:
5454

5555

5656
class FocalLoss(nn.Module):
57-
def __init__(self, gamma=3.0, weight=None, reduction="mean"):
57+
def __init__(self, gamma: float = 3.0, weight: torch.Tensor | None = None, reduction: str | None = "mean"):
5858
super().__init__()
5959
self.gamma = gamma
6060
self.ce = nn.CrossEntropyLoss(weight=weight, reduction="none")
@@ -73,17 +73,25 @@ def forward(self, logits, targets):
7373

7474
# --- Custom Trainer for weighted loss (pretraining stage) ---
7575
class WeightedTrainer(Trainer):
76-
def __init__(self, class_weights: torch.Tensor | None = None, use_focal_loss: bool = False, *args, **kwargs):
76+
def __init__(
77+
self,
78+
class_weights: torch.Tensor | None = None,
79+
use_focal_loss: bool = False,
80+
gamma: float = 3.0,
81+
*args,
82+
**kwargs,
83+
):
7784
super().__init__(*args, **kwargs)
7885
self.class_weights = class_weights.to(self.args.device) if class_weights is not None else None
7986
self.use_focal_loss = use_focal_loss
87+
self.gamma = gamma
8088

8189
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
8290
labels = inputs.get("labels")
8391
outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
8492
logits = outputs.logits
8593
if self.use_focal_loss:
86-
loss_fct = FocalLoss(weight=self.class_weights)
94+
loss_fct = FocalLoss(gamma=self.gamma, weight=self.class_weights)
8795
else:
8896
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
8997
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
@@ -94,7 +102,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
94102
class DualEncoderForSequenceClassification(PreTrainedModel):
95103
config_class = AutoConfig
96104

97-
def __init__(self, config, use_focal_loss: bool = False):
105+
def __init__(
106+
self,
107+
config,
108+
use_focal_loss: bool = False,
109+
gamma: float = 3.0,
110+
):
98111
super().__init__(config)
99112
self.num_labels = config.num_labels
100113
# instantiate two encoders from the pretrained config
@@ -108,6 +121,7 @@ def __init__(self, config, use_focal_loss: bool = False):
108121
self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
109122
self.classifier = nn.Linear(hidden_size, config.num_labels)
110123
self.use_focal_loss = use_focal_loss
124+
self.gamma = gamma
111125
self.post_init()
112126

113127
def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=True):
@@ -130,7 +144,7 @@ def forward(self, input_ids=None, attention_mask=None, labels=None, return_dict=
130144
if hasattr(self.config, "class_weights") and self.config.class_weights is not None:
131145
cw = torch.tensor(self.config.class_weights, device=logits.device, dtype=torch.float)
132146
if self.use_focal_loss:
133-
loss_fct = FocalLoss(weight=cw)
147+
loss_fct = FocalLoss(gamma=self.gamma, weight=cw)
134148
else:
135149
loss_fct = nn.CrossEntropyLoss(weight=cw)
136150
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -289,7 +303,8 @@ def train_pretrain_stage(args, logger):
289303
eval_dataset=val_ds,
290304
compute_metrics=compute_metrics,
291305
class_weights=class_weights,
292-
use_focal_loss=True,
306+
use_focal_loss=args.use_focal_loss,
307+
gamma=args.gamma,
293308
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
294309
)
295310
logger.info("Starting pre-training...")
@@ -315,7 +330,11 @@ def train_main_stage(args, logger, pretrain_trainer, tokenizer, full_df, freeze_
315330
# Build Dual Encoder
316331
config = AutoConfig.from_pretrained(args.model, num_labels=2)
317332
config.class_weights = class_weights.tolist()
318-
combined = DualEncoderForSequenceClassification(config, use_focal_loss=True)
333+
combined = DualEncoderForSequenceClassification(
334+
config,
335+
use_focal_loss=args.use_focal_loss,
336+
gamma=args.gamma,
337+
)
319338

320339
# Load base encoder weights for text encoder (fresh from pretrained)
321340
base_model = AutoModel.from_pretrained(args.model)
@@ -483,6 +502,24 @@ def main():
483502
action="store_true",
484503
help="Skip pretraining stage and train dual encoder from scratch",
485504
)
505+
parser.add_argument(
506+
"--freeze-bio-encoder",
507+
dest="freeze_bio_encoder",
508+
action="store_true",
509+
help="Freeze the bio encoder during main task training",
510+
)
511+
parser.add_argument(
512+
"--use-focal-loss",
513+
dest="use_focal_loss",
514+
action="store_true",
515+
help="Use Focal Loss instead of Cross-Entropy Loss",
516+
)
517+
parser.add_argument(
518+
"--gamma",
519+
type=float,
520+
default=2.0,
521+
help="Gamma parameter for Focal Loss",
522+
)
486523
args = parser.parse_args()
487524

488525
# Setup logging and directories
@@ -550,7 +587,9 @@ def main():
550587
full_df = preprocess_df_texts(full_df, spanish=(args.lang in ["es", "both"]))
551588

552589
# Main Stage
553-
main_trainer, test_dataset = train_main_stage(args, logger, pretrain_trainer, tokenizer, full_df)
590+
main_trainer, test_dataset = train_main_stage(
591+
args, logger, pretrain_trainer, tokenizer, full_df, freeze_bio_encoder=args.freeze_bio_encoder
592+
)
554593

555594
# Evaluate on test set
556595
logger.info("Evaluating main model on test set...")

0 commit comments

Comments
 (0)