Skip to content

Commit afc3d6f

Browse files
committed
[feature] Visualizer compatible with MultiTaskDataSample
1 parent 7d850df commit afc3d6f

File tree

2 files changed

+103
-28
lines changed

2 files changed

+103
-28
lines changed

mmpretrain/visualization/visualizer.py

+62-27
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mmengine.visualization.utils import img_from_canvas
1212

1313
from mmpretrain.registry import VISUALIZERS
14-
from mmpretrain.structures import DataSample
14+
from mmpretrain.structures import DataSample, MultiTaskDataSample
1515
from .utils import create_figure, get_adaptive_scale
1616

1717

@@ -114,33 +114,9 @@ def visualize_cls(self,
114114
texts = []
115115
self.set_image(image)
116116

117-
if draw_gt and 'gt_label' in data_sample:
118-
idx = data_sample.gt_label.tolist()
119-
class_labels = [''] * len(idx)
120-
if classes is not None:
121-
class_labels = [f' ({classes[i]})' for i in idx]
122-
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
123-
prefix = 'Ground truth: '
124-
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
125-
126-
if draw_pred and 'pred_label' in data_sample:
127-
idx = data_sample.pred_label.tolist()
128-
score_labels = [''] * len(idx)
129-
class_labels = [''] * len(idx)
130-
if draw_score and 'pred_score' in data_sample:
131-
score_labels = [
132-
f', {data_sample.pred_score[i].item():.2f}' for i in idx
133-
]
134-
135-
if classes is not None:
136-
class_labels = [f' ({classes[i]})' for i in idx]
117+
self.draw_gt(data_sample, classes, draw_gt, texts)
137118

138-
labels = [
139-
str(idx[i]) + score_labels[i] + class_labels[i]
140-
for i in range(len(idx))
141-
]
142-
prefix = 'Prediction: '
143-
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
119+
self.draw_pred(data_sample, classes, draw_pred, draw_score, texts)
144120

145121
img_scale = get_adaptive_scale(image.shape[:2])
146122
text_cfg = {
@@ -167,6 +143,65 @@ def visualize_cls(self,
167143

168144
return drawn_img
169145

146+
def draw_pred(self,
147+
data_sample: DataSample,
148+
classes: Optional[Sequence[str]],
149+
draw_pred: bool,
150+
draw_score: bool,
151+
texts: Sequence[str],
152+
parent_task: str = ''):
153+
if isinstance(data_sample, MultiTaskDataSample):
154+
for task in data_sample.tasks:
155+
sub_task = f'{parent_task}_{task}' if parent_task else task
156+
self.draw_pred(
157+
data_sample.get(task), classes, draw_pred, draw_score,
158+
texts, sub_task)
159+
else:
160+
if draw_pred and 'pred_label' in data_sample:
161+
idx = data_sample.pred_label.tolist()
162+
score_labels = [''] * len(idx)
163+
class_labels = [''] * len(idx)
164+
if draw_score and 'pred_score' in data_sample:
165+
score_labels = [
166+
f', {data_sample.pred_score[i].item():.2f}'
167+
for i in idx
168+
]
169+
170+
if classes is not None:
171+
class_labels = [f' ({classes[i]})' for i in idx]
172+
173+
labels = [
174+
str(idx[i]) + score_labels[i] + class_labels[i]
175+
for i in range(len(idx))
176+
]
177+
prefix = f'{parent_task} Prediction: ' if parent_task \
178+
else 'Prediction: '
179+
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
180+
181+
def draw_gt(self,
182+
data_sample: DataSample,
183+
classes: Optional[Sequence[str]],
184+
draw_gt: bool,
185+
texts: Sequence[str],
186+
parent_task: str = ''):
187+
if isinstance(data_sample, MultiTaskDataSample):
188+
for task in data_sample.tasks:
189+
sub_task = f'{parent_task}_{task}' if parent_task else task
190+
self.draw_gt(
191+
data_sample.get(task), classes, draw_gt, texts, sub_task)
192+
else:
193+
if draw_gt and 'gt_label' in data_sample:
194+
idx = data_sample.gt_label.tolist()
195+
class_labels = [''] * len(idx)
196+
if classes is not None:
197+
class_labels = [f' ({classes[i]})' for i in idx]
198+
labels = [
199+
str(idx[i]) + class_labels[i] for i in range(len(idx))
200+
]
201+
prefix = f'{parent_task} Ground truth: ' if parent_task \
202+
else 'Ground truth: '
203+
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
204+
170205
@master_only
171206
def visualize_image_retrieval(self,
172207
image: np.ndarray,

tests/test_visualization/test_visualizer.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import torch
99

10-
from mmpretrain.structures import DataSample
10+
from mmpretrain.structures import DataSample, MultiTaskDataSample
1111
from mmpretrain.visualization import UniversalVisualizer
1212

1313

@@ -123,6 +123,46 @@ def draw_texts(text, font_sizes, *_, **__):
123123
data_sample,
124124
rescale_factor=2.)
125125

126+
def test_visualize_multitask_cls(self):
127+
image = np.ones((1000, 1000, 3), np.uint8)
128+
gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1}
129+
data_sample = MultiTaskDataSample()
130+
task_sample = DataSample().set_gt_label(
131+
gt_label['task1']).set_pred_label(1).set_pred_score(
132+
torch.tensor([0.1, 0.8, 0.1]))
133+
data_sample.set_field(task_sample, 'task1')
134+
data_sample.set_field(MultiTaskDataSample(), 'task0')
135+
for task_name in gt_label['task0']:
136+
task_sample = DataSample().set_gt_label(
137+
gt_label['task0'][task_name]).set_pred_label(2).set_pred_score(
138+
torch.tensor([0.1, 0.4, 0.5]))
139+
data_sample.task0.set_field(task_sample, task_name)
140+
141+
# Test show
142+
def mock_show(drawn_img, win_name, wait_time):
143+
self.assertFalse((image == drawn_img).all())
144+
self.assertEqual(win_name, 'test_cls')
145+
self.assertEqual(wait_time, 0)
146+
147+
with patch.object(self.vis, 'show', mock_show):
148+
self.vis.visualize_cls(
149+
image=image,
150+
data_sample=data_sample,
151+
show=True,
152+
name='test_cls',
153+
step=2)
154+
155+
# Test storage backend.
156+
save_file = osp.join(self.tmpdir.name,
157+
'vis_data/vis_image/test_cls_2.png')
158+
self.assertTrue(osp.exists(save_file))
159+
160+
# Test out_file
161+
out_file = osp.join(self.tmpdir.name, 'results_2.png')
162+
self.vis.visualize_cls(
163+
image=image, data_sample=data_sample, out_file=out_file)
164+
self.assertTrue(osp.exists(out_file))
165+
126166
def test_visualize_image_retrieval(self):
127167
image = np.ones((10, 10, 3), np.uint8)
128168
data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])

0 commit comments

Comments
 (0)