Skip to content

Commit 902745f

Browse files
sivanravidosmmdanzigerclaude
authored
Filter perplexity at get_metrics() construction instead of post-hoc pop (#38)
Removes dependency on MetricCollection.pop() which is absent in torchmetrics>=1.2, unblocking the torchmetrics unpin. Callers pass exclude={"perplexity"} when multiple tasks share a metric_key. Co-authored-by: Michael-Danziger <michael.danziger@ibm.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 47194b4 commit 902745f

3 files changed

Lines changed: 13 additions & 8 deletions

File tree

bmfm_targets/training/losses/task.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,18 @@ def extract_metric_inputs(
197197

198198
return self.metric_key, model_outputs, gt_labels
199199

200-
def get_metrics(self) -> MetricCollection:
200+
def get_metrics(self, exclude: set[str] | None = None) -> MetricCollection:
201201
"""
202202
Get metric collection for this task.
203203
204204
Uses objective's default metrics unless overridden via metrics parameter.
205205
Filters metrics based on output size (classification vs regression).
206206
207+
Parameters
208+
----------
209+
exclude:
210+
Optional set of metric names to omit from the collection.
211+
207212
Returns
208213
-------
209214
MetricCollection: Collection of metrics for this task
@@ -227,6 +232,7 @@ def get_metrics(self) -> MetricCollection:
227232
{
228233
mt["name"]: metrics.get_metric_object(mt, num_classes)
229234
for mt in metric_configs
235+
if not (exclude and mt["name"] in exclude)
230236
}
231237
)
232238

bmfm_targets/training/modules/base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,12 +249,11 @@ def initialize_metrics(self):
249249

250250
metrics_dict = {}
251251
for loss_task in self.loss_tasks:
252-
metrics_dict[loss_task.metric_key] = loss_task.get_metrics()
253-
# Remove perplexity when multiple tasks share a metric_key
254-
# (perplexity needs logits, but duplicated keys use predictions)
255-
if metric_key_counts[loss_task.metric_key] > 1:
256-
if "perplexity" in metrics_dict[loss_task.metric_key]:
257-
metrics_dict[loss_task.metric_key].pop("perplexity")
252+
# Perplexity needs logits, but duplicated metric_keys use predictions
253+
exclude = (
254+
{"perplexity"} if metric_key_counts[loss_task.metric_key] > 1 else None
255+
)
256+
metrics_dict[loss_task.metric_key] = loss_task.get_metrics(exclude=exclude)
258257

259258
self.train_metrics = MultitaskWrapper(metrics_dict).clone()
260259
self.val_metrics = MultitaskWrapper(metrics_dict).clone()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"hydra-core",
3333
"clearml>1.13,<2",
3434
"rich",
35-
"torchmetrics==1.1.0",
35+
"torchmetrics",
3636
"tensorboardX",
3737
"pandas>=2,<3",
3838
"einops",

0 commit comments

Comments
 (0)