
Commit 3f37077

Author: Sharad Sirsat (committed)
fix(tests): fix attribute error
1 parent 1dcd6fa commit 3f37077

2 files changed: +44, -4 lines


mmocr/evaluation/evaluator/multi_datasets_evaluator.py (+23, -3)

@@ -92,9 +92,29 @@ def evaluate(self, size: int) -> dict:
                 metrics_results.update(metric_results)
             metric.results.clear()
         if is_main_process():
-            metrics_results = [metrics_results]
+            averaged_results = self.average_results(metrics_results)
         else:
-            metrics_results = [None]  # type: ignore
+            averaged_results = None
+
+        metrics_results = [metrics_results]
         broadcast_object_list(metrics_results)
+        broadcast_object_list([averaged_results])
+
+        return metrics_results[0], averaged_results
+
+    def average_results(self, metrics_results):
+        """Compute the average of metric results across all datasets.
+
+        Args:
+            metrics_results (dict): Evaluation results of all metrics.
+
+        Returns:
+            dict: Average evaluation results of all metrics.
+        """
+        averaged_results = {}
+        num_datasets = len(self.dataset_prefixes)
+        for metric_name, metric_result in metrics_results.items():
+            metric_avg = metric_result / num_datasets
+            averaged_results[metric_name] = metric_avg
 
-        return metrics_results[0]
+        return averaged_results
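
As the diff reads, evaluate() now broadcasts and returns a (metrics, averaged_results) pair, and average_results() divides every metric value by the number of dataset prefixes. Below is a minimal standalone sketch of that assumed averaging semantics; the helper name and sample metric keys are illustrative only and are not part of mmocr's API.

from collections import OrderedDict


def average_results_sketch(metrics_results, dataset_prefixes):
    """Divide every metric value by the number of dataset prefixes."""
    num_datasets = len(dataset_prefixes)
    return {name: value / num_datasets
            for name, value in metrics_results.items()}


# Example: two dataset prefixes, so each value is halved.
metrics = OrderedDict({
    'Fake/Toy/accuracy': 0.9,
    'Fake/Toy2/accuracy': 0.7,
})
print(average_results_sketch(metrics, dataset_prefixes=['Toy', 'Toy2']))
# -> {'Fake/Toy/accuracy': 0.45, 'Fake/Toy2/accuracy': 0.35}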

tests/test_evaluation/test_evaluator/test_multi_datasets_evaluator.py (+21, -1)

@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
 import math
+from collections import OrderedDict
 from typing import Dict, List, Optional
 from unittest import TestCase
 
@@ -75,7 +76,7 @@ def generate_test_results(size, batch_size, pred, label):
         predictions = [
             BaseDataElement(pred=pred, label=label) for _ in range(bs)
         ]
-        yield (data_batch, predictions)
+        yield data_batch, predictions
 
 
 class TestMultiDatasetsEvaluator(TestCase):
@@ -126,3 +127,22 @@ def test_composed_metrics(self):
         metrics = evaluator.evaluate(size=size)
         self.assertIn('Fake/Toy/accuracy', metrics)
         self.assertIn('Fake/accuracy', metrics)
+
+        metrics_results = OrderedDict({
+            'dataset1/metric1/accuracy': 0.9,
+            'dataset1/metric2/f1_score': 0.8,
+            'dataset2/metric1/accuracy': 0.85,
+            'dataset2/metric2/f1_score': 0.75
+        })
+
+        evaluator = MultiDatasetsEvaluator([], [])
+        averaged_results = evaluator.average_results(metrics_results)
+
+        expected_averaged_results = {
+            'dataset1/metric1/accuracy': 0.9,
+            'dataset1/metric2/f1_score': 0.8,
+            'dataset2/metric1/accuracy': 0.85,
+            'dataset2/metric2/f1_score': 0.75
+        }
+
+        self.assertEqual(averaged_results, expected_averaged_results)
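
The added assertions call average_results() directly on a hand-built OrderedDict. A self-contained variant of the same test pattern, with a local stand-in for MultiDatasetsEvaluator.average_results so it runs without an mmocr checkout (all names and values below are illustrative assumptions):

from collections import OrderedDict
from unittest import TestCase, main


def _average_results_stub(metrics_results, num_datasets):
    # Assumed behaviour of the new method: each metric value is divided
    # by the number of dataset prefixes.
    return {name: value / num_datasets
            for name, value in metrics_results.items()}


class TestAverageResultsSketch(TestCase):

    def test_average_over_two_datasets(self):
        metrics_results = OrderedDict({
            'dataset1/metric1/accuracy': 0.9,
            'dataset2/metric1/accuracy': 0.7,
        })
        averaged = _average_results_stub(metrics_results, num_datasets=2)
        expected = {
            'dataset1/metric1/accuracy': 0.45,
            'dataset2/metric1/accuracy': 0.35,
        }
        self.assertEqual(averaged, expected)


if __name__ == '__main__':
    main()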
