
Commit 936c2a4

[feature] Visualizer compatible with MultiTaskDataSample
1 parent 7d850df commit 936c2a4

2 files changed: +106 −29 lines

mmpretrain/visualization/visualizer.py

+65 −28
@@ -11,7 +11,7 @@
 from mmengine.visualization.utils import img_from_canvas

 from mmpretrain.registry import VISUALIZERS
-from mmpretrain.structures import DataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
 from .utils import create_figure, get_adaptive_scale


@@ -99,6 +99,67 @@ def visualize_cls(self,
         Returns:
             np.ndarray: The visualization image.
         """
+
+        def _draw_gt(data_sample: DataSample,
+                     classes: Optional[Sequence[str]],
+                     draw_gt: bool,
+                     texts: Sequence[str],
+                     parent_task: str = ''):
+            if isinstance(data_sample, MultiTaskDataSample):
+                for task in data_sample.tasks:
+                    sub_task = f'{parent_task}_{task}' if parent_task else task
+                    _draw_gt(
+                        data_sample.get(task), classes, draw_gt, texts,
+                        sub_task)
+            else:
+                if draw_gt and 'gt_label' in data_sample:
+                    idx = data_sample.gt_label.tolist()
+                    class_labels = [''] * len(idx)
+                    if classes is not None:
+                        class_labels = [f' ({classes[i]})' for i in idx]
+                    labels = [
+                        str(idx[i]) + class_labels[i] for i in range(len(idx))
+                    ]
+                    prefix = f'{parent_task} Ground truth: ' if parent_task \
+                        else 'Ground truth: '
+                    texts.append(prefix +
+                                 ('\n' + ' ' * len(prefix)).join(labels))
+
+        def _draw_pred(data_sample: DataSample,
+                       classes: Optional[Sequence[str]],
+                       draw_pred: bool,
+                       draw_score: bool,
+                       texts: Sequence[str],
+                       parent_task: str = ''):
+            if isinstance(data_sample, MultiTaskDataSample):
+                for task in data_sample.tasks:
+                    sub_task = f'{parent_task}_{task}' if parent_task else task
+                    _draw_pred(
+                        data_sample.get(task), classes, draw_pred, draw_score,
+                        texts, sub_task)
+            else:
+                if draw_pred and 'pred_label' in data_sample:
+                    idx = data_sample.pred_label.tolist()
+                    score_labels = [''] * len(idx)
+                    class_labels = [''] * len(idx)
+                    if draw_score and 'pred_score' in data_sample:
+                        score_labels = [
+                            f', {data_sample.pred_score[i].item():.2f}'
+                            for i in idx
+                        ]
+
+                    if classes is not None:
+                        class_labels = [f' ({classes[i]})' for i in idx]
+
+                    labels = [
+                        str(idx[i]) + score_labels[i] + class_labels[i]
+                        for i in range(len(idx))
+                    ]
+                    prefix = f'{parent_task} Prediction: ' if parent_task \
+                        else 'Prediction: '
+                    texts.append(prefix +
+                                 ('\n' + ' ' * len(prefix)).join(labels))
+
         if self.dataset_meta is not None:
             classes = classes or self.dataset_meta.get('classes', None)

@@ -114,33 +175,9 @@ def visualize_cls(self,
         texts = []
         self.set_image(image)

-        if draw_gt and 'gt_label' in data_sample:
-            idx = data_sample.gt_label.tolist()
-            class_labels = [''] * len(idx)
-            if classes is not None:
-                class_labels = [f' ({classes[i]})' for i in idx]
-            labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
-            prefix = 'Ground truth: '
-            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
-
-        if draw_pred and 'pred_label' in data_sample:
-            idx = data_sample.pred_label.tolist()
-            score_labels = [''] * len(idx)
-            class_labels = [''] * len(idx)
-            if draw_score and 'pred_score' in data_sample:
-                score_labels = [
-                    f', {data_sample.pred_score[i].item():.2f}' for i in idx
-                ]
-
-            if classes is not None:
-                class_labels = [f' ({classes[i]})' for i in idx]
-
-            labels = [
-                str(idx[i]) + score_labels[i] + class_labels[i]
-                for i in range(len(idx))
-            ]
-            prefix = 'Prediction: '
-            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
+        _draw_gt(data_sample, classes, draw_gt, texts)
+
+        _draw_pred(data_sample, classes, draw_pred, draw_score, texts)

         img_scale = get_adaptive_scale(image.shape[:2])
         text_cfg = {
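
In short, visualize_cls now recurses into MultiTaskDataSample fields: nested task names are flattened with underscores (e.g. task0_task00) and used as prefixes for the drawn "Ground truth" and "Prediction" lines, while a plain DataSample keeps the previous behaviour. A minimal usage sketch follows, assuming an environment with this commit installed; the label values, scores, image size, and the no-argument UniversalVisualizer() construction are illustrative assumptions, not taken from the diff.

    # Sketch: visualize a nested multi-task sample (hypothetical values).
    import numpy as np
    import torch

    from mmpretrain.structures import DataSample, MultiTaskDataSample
    from mmpretrain.visualization import UniversalVisualizer

    vis = UniversalVisualizer()  # assumes the default constructor is enough here

    # 'task0' holds a nested sub-task; 'task1' is a plain classification task.
    nested = MultiTaskDataSample()
    nested.set_field(
        DataSample().set_gt_label(2).set_pred_label(2).set_pred_score(
            torch.tensor([0.1, 0.4, 0.5])), 'task00')

    sample = MultiTaskDataSample()
    sample.set_field(nested, 'task0')
    sample.set_field(
        DataSample().set_gt_label(1).set_pred_label(1).set_pred_score(
            torch.tensor([0.1, 0.8, 0.1])), 'task1')

    # Ground truths are drawn first, then predictions, each line prefixed with
    # the flattened task name, e.g. 'task0_task00 Ground truth: 2'.
    image = np.ones((224, 224, 3), np.uint8)
    drawn = vis.visualize_cls(image=image, data_sample=sample)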

tests/test_visualization/test_visualizer.py

+41 −1
@@ -7,7 +7,7 @@
 import numpy as np
 import torch

-from mmpretrain.structures import DataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
 from mmpretrain.visualization import UniversalVisualizer


@@ -123,6 +123,46 @@ def draw_texts(text, font_sizes, *_, **__):
                 data_sample,
                 rescale_factor=2.)

+    def test_visualize_multitask_cls(self):
+        image = np.ones((1000, 1000, 3), np.uint8)
+        gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1}
+        data_sample = MultiTaskDataSample()
+        task_sample = DataSample().set_gt_label(
+            gt_label['task1']).set_pred_label(1).set_pred_score(
+                torch.tensor([0.1, 0.8, 0.1]))
+        data_sample.set_field(task_sample, 'task1')
+        data_sample.set_field(MultiTaskDataSample(), 'task0')
+        for task_name in gt_label['task0']:
+            task_sample = DataSample().set_gt_label(
+                gt_label['task0'][task_name]).set_pred_label(2).set_pred_score(
+                    torch.tensor([0.1, 0.4, 0.5]))
+            data_sample.task0.set_field(task_sample, task_name)
+
+        # Test show
+        def mock_show(drawn_img, win_name, wait_time):
+            self.assertFalse((image == drawn_img).all())
+            self.assertEqual(win_name, 'test_cls')
+            self.assertEqual(wait_time, 0)
+
+        with patch.object(self.vis, 'show', mock_show):
+            self.vis.visualize_cls(
+                image=image,
+                data_sample=data_sample,
+                show=True,
+                name='test_cls',
+                step=2)
+
+        # Test storage backend.
+        save_file = osp.join(self.tmpdir.name,
+                             'vis_data/vis_image/test_cls_2.png')
+        self.assertTrue(osp.exists(save_file))
+
+        # Test out_file
+        out_file = osp.join(self.tmpdir.name, 'results_2.png')
+        self.vis.visualize_cls(
+            image=image, data_sample=data_sample, out_file=out_file)
+        self.assertTrue(osp.exists(out_file))
+
     def test_visualize_image_retrieval(self):
         image = np.ones((10, 10, 3), np.uint8)
         data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])
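
For reference, with the sample built in this test and no classes configured, the recursive helpers would draw one task-prefixed line per label, roughly as sketched below: ground-truth lines come before prediction lines, the ordering inside each group follows MultiTaskDataSample.tasks (shown here in an arbitrary order), and each score is pred_score indexed by the predicted label, formatted with ':.2f'.

    task0_task00 Ground truth: 2
    task0_task01 Ground truth: 1
    task1 Ground truth: 1
    task0_task00 Prediction: 2, 0.50
    task0_task01 Prediction: 2, 0.50
    task1 Prediction: 1, 0.80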
