3333 MultilabelAUROC ,
3434 MultilabelAveragePrecision ,
3535)
36+ from torchmetrics .regression import PearsonCorrCoef
3637from torchmetrics .text import BLEUScore
3738from torchmetrics .utilities .checks import _allclose_recursive
3839from unittests ._helpers import seed_all
@@ -328,30 +329,35 @@ def compute(self):
328329 "metrics, expected, preds, target" ,
329330 [
330331 # single metric forms its own compute group
331- (MulticlassAccuracy (num_classes = 3 ), {0 : ["MulticlassAccuracy" ]}, _mc_preds , _mc_target ),
332+ pytest .param (
333+ MulticlassAccuracy (num_classes = 3 ), {0 : ["MulticlassAccuracy" ]}, _mc_preds , _mc_target , id = "single_metric"
334+ ),
332335 # two metrics of same class forms a compute group
333- (
336+ pytest . param (
334337 {"acc0" : MulticlassAccuracy (num_classes = 3 ), "acc1" : MulticlassAccuracy (num_classes = 3 )},
335338 {0 : ["acc0" , "acc1" ]},
336339 _mc_preds ,
337340 _mc_target ,
341+ id = "two_metrics_of_same_class" ,
338342 ),
339343 # two metrics from registry forms a compute group
340- (
344+ pytest . param (
341345 [MulticlassPrecision (num_classes = 3 ), MulticlassRecall (num_classes = 3 )],
342346 {0 : ["MulticlassPrecision" , "MulticlassRecall" ]},
343347 _mc_preds ,
344348 _mc_target ,
349+ id = "two_metrics_from_registry" ,
345350 ),
346351 # two metrics from different classes gives two compute groups
347- (
352+ pytest . param (
348353 [MulticlassConfusionMatrix (num_classes = 3 ), MulticlassRecall (num_classes = 3 )],
349354 {0 : ["MulticlassConfusionMatrix" ], 1 : ["MulticlassRecall" ]},
350355 _mc_preds ,
351356 _mc_target ,
357+ id = "two_metrics_from_different_classes" ,
352358 ),
353359 # multi group multi metric
354- (
360+ pytest . param (
355361 [
356362 MulticlassConfusionMatrix (num_classes = 3 ),
357363 MulticlassCohenKappa (num_classes = 3 ),
@@ -361,9 +367,10 @@ def compute(self):
361367 {0 : ["MulticlassConfusionMatrix" , "MulticlassCohenKappa" ], 1 : ["MulticlassRecall" , "MulticlassPrecision" ]},
362368 _mc_preds ,
363369 _mc_target ,
370+ id = "multi_group_multi_metric" ,
364371 ),
365372 # Complex example
366- (
373+ pytest . param (
367374 {
368375 "acc" : MulticlassAccuracy (num_classes = 3 ),
369376 "acc2" : MulticlassAccuracy (num_classes = 3 ),
@@ -375,19 +382,21 @@ def compute(self):
375382 {0 : ["acc" , "acc2" , "f1" , "recall" ], 1 : ["acc3" ], 2 : ["confmat" ]},
376383 _mc_preds ,
377384 _mc_target ,
385+ id = "complex_example" ,
378386 ),
379387 # With list states
380- (
388+ pytest . param (
381389 [
382390 MulticlassAUROC (num_classes = 3 , average = "macro" ),
383391 MulticlassAveragePrecision (num_classes = 3 , average = "macro" ),
384392 ],
385393 {0 : ["MulticlassAUROC" , "MulticlassAveragePrecision" ]},
386394 _mc_preds ,
387395 _mc_target ,
396+ id = "with_list_states" ,
388397 ),
389398 # Nested collections
390- (
399+ pytest . param (
391400 [
392401 MetricCollection (
393402 MultilabelAUROC (num_labels = 3 , average = "micro" ),
@@ -410,6 +419,7 @@ def compute(self):
410419 },
411420 _ml_preds ,
412421 _ml_target ,
422+ id = "nested_collections" ,
413423 ),
414424 ],
415425)
@@ -796,3 +806,39 @@ def test_collection_update():
796806
797807 for k , v in expected .items ():
798808 torch .testing .assert_close (actual = actual .get (k ), expected = v , rtol = 1e-4 , atol = 1e-4 )
809+
810+
811+ def test_collection_state_being_re_established_after_copy ():
812+ """Check that shared metrics states when using compute groups are re-established after a copy.
813+
814+ See issue: https://github.com/Lightning-AI/torchmetrics/issues/2896
815+
816+ """
817+ m1 , m2 = PearsonCorrCoef (), PearsonCorrCoef ()
818+ m12 = MetricCollection ({"m1" : m1 , "m2" : m2 }, compute_groups = True )
819+ x1 , y1 = torch .randn (100 ), torch .randn (100 )
820+ m12 .update (x1 , y1 )
821+ assert m12 .compute_groups == {0 : ["m1" , "m2" ]}
822+
823+ # Check that the states are pointing to the same location
824+ assert not m12 ._state_is_copy
825+ assert m12 .m1 .mean_x .data_ptr () == m12 .m2 .mean_x .data_ptr (), "States should point to the same location"
826+
827+ # Break the references between the states
828+ _ = m12 .items ()
829+ assert m12 ._state_is_copy
830+ assert m12 .m1 .mean_x .data_ptr () != m12 .m2 .mean_x .data_ptr (), "States should not point to the same location"
831+
832+ # Update should restore the references between the states
833+ x2 , y2 = torch .randn (100 ), torch .randn (100 )
834+
835+ m12 .update (x2 , y2 )
836+ assert not m12 ._state_is_copy
837+ assert m12 .m1 .mean_x .data_ptr () == m12 .m2 .mean_x .data_ptr (), "States should point to the same location"
838+
839+ x3 , y3 = torch .randn (100 ), torch .randn (100 )
840+ m12 .update (x3 , y3 )
841+
842+ assert not m12 ._state_is_copy
843+ assert m12 .m1 .mean_x .data_ptr () == m12 .m2 .mean_x .data_ptr (), "States should point to the same location"
844+ assert m12 ._equal_metric_states (m12 .m1 , m12 .m2 )
0 commit comments