diff --git a/asparagus/modules/lightning_modules/linear_probe_module.py b/asparagus/modules/lightning_modules/linear_probe_module.py index 5a31199..40ac496 100644 --- a/asparagus/modules/lightning_modules/linear_probe_module.py +++ b/asparagus/modules/lightning_modules/linear_probe_module.py @@ -11,6 +11,7 @@ from torchmetrics.classification import ( MulticlassAUROC, MulticlassAveragePrecision, + MulticlassF1Score, ) from torchvision import transforms from typing import List, Optional @@ -192,6 +193,7 @@ def configure_test_metrics(self): { "AUROC_macro": MulticlassAUROC(num_classes=self.num_classes, average="macro"), "AUPRC_macro": MulticlassAveragePrecision(num_classes=self.num_classes, average="macro"), + "F1_macro": MulticlassF1Score(num_classes=self.num_classes, average="macro"), } ) @@ -205,6 +207,7 @@ def configure_metrics(self, prefix: str): f"{prefix}/{head_name}/auprc_macro": MulticlassAveragePrecision( num_classes=self.num_classes, average="macro" ), + f"{prefix}/{head_name}/f1_macro": MulticlassF1Score(num_classes=self.num_classes, average="macro"), } ) return metrics