11
11
from mmengine .visualization .utils import img_from_canvas
12
12
13
13
from mmpretrain .registry import VISUALIZERS
14
- from mmpretrain .structures import DataSample
14
+ from mmpretrain .structures import DataSample , MultiTaskDataSample
15
15
from .utils import create_figure , get_adaptive_scale
16
16
17
17
@@ -99,6 +99,67 @@ def visualize_cls(self,
99
99
Returns:
100
100
np.ndarray: The visualization image.
101
101
"""
102
+
103
+ def _draw_gt (data_sample : DataSample ,
104
+ classes : Optional [Sequence [str ]],
105
+ draw_gt : bool ,
106
+ texts : Sequence [str ],
107
+ parent_task : str = '' ):
108
+ if isinstance (data_sample , MultiTaskDataSample ):
109
+ for task in data_sample .tasks :
110
+ sub_task = f'{ parent_task } _{ task } ' if parent_task else task
111
+ _draw_gt (
112
+ data_sample .get (task ), classes , draw_gt , texts ,
113
+ sub_task )
114
+ else :
115
+ if draw_gt and 'gt_label' in data_sample :
116
+ idx = data_sample .gt_label .tolist ()
117
+ class_labels = ['' ] * len (idx )
118
+ if classes is not None :
119
+ class_labels = [f' ({ classes [i ]} )' for i in idx ]
120
+ labels = [
121
+ str (idx [i ]) + class_labels [i ] for i in range (len (idx ))
122
+ ]
123
+ prefix = f'{ parent_task } Ground truth: ' if parent_task \
124
+ else 'Ground truth: '
125
+ texts .append (prefix +
126
+ ('\n ' + ' ' * len (prefix )).join (labels ))
127
+
128
+ def _draw_pred (data_sample : DataSample ,
129
+ classes : Optional [Sequence [str ]],
130
+ draw_pred : bool ,
131
+ draw_score : bool ,
132
+ texts : Sequence [str ],
133
+ parent_task : str = '' ):
134
+ if isinstance (data_sample , MultiTaskDataSample ):
135
+ for task in data_sample .tasks :
136
+ sub_task = f'{ parent_task } _{ task } ' if parent_task else task
137
+ _draw_pred (
138
+ data_sample .get (task ), classes , draw_pred , draw_score ,
139
+ texts , sub_task )
140
+ else :
141
+ if draw_pred and 'pred_label' in data_sample :
142
+ idx = data_sample .pred_label .tolist ()
143
+ score_labels = ['' ] * len (idx )
144
+ class_labels = ['' ] * len (idx )
145
+ if draw_score and 'pred_score' in data_sample :
146
+ score_labels = [
147
+ f', { data_sample .pred_score [i ].item ():.2f} '
148
+ for i in idx
149
+ ]
150
+
151
+ if classes is not None :
152
+ class_labels = [f' ({ classes [i ]} )' for i in idx ]
153
+
154
+ labels = [
155
+ str (idx [i ]) + score_labels [i ] + class_labels [i ]
156
+ for i in range (len (idx ))
157
+ ]
158
+ prefix = f'{ parent_task } Prediction: ' if parent_task \
159
+ else 'Prediction: '
160
+ texts .append (prefix +
161
+ ('\n ' + ' ' * len (prefix )).join (labels ))
162
+
102
163
if self .dataset_meta is not None :
103
164
classes = classes or self .dataset_meta .get ('classes' , None )
104
165
@@ -114,33 +175,9 @@ def visualize_cls(self,
114
175
texts = []
115
176
self .set_image (image )
116
177
117
- if draw_gt and 'gt_label' in data_sample :
118
- idx = data_sample .gt_label .tolist ()
119
- class_labels = ['' ] * len (idx )
120
- if classes is not None :
121
- class_labels = [f' ({ classes [i ]} )' for i in idx ]
122
- labels = [str (idx [i ]) + class_labels [i ] for i in range (len (idx ))]
123
- prefix = 'Ground truth: '
124
- texts .append (prefix + ('\n ' + ' ' * len (prefix )).join (labels ))
125
-
126
- if draw_pred and 'pred_label' in data_sample :
127
- idx = data_sample .pred_label .tolist ()
128
- score_labels = ['' ] * len (idx )
129
- class_labels = ['' ] * len (idx )
130
- if draw_score and 'pred_score' in data_sample :
131
- score_labels = [
132
- f', { data_sample .pred_score [i ].item ():.2f} ' for i in idx
133
- ]
134
-
135
- if classes is not None :
136
- class_labels = [f' ({ classes [i ]} )' for i in idx ]
137
-
138
- labels = [
139
- str (idx [i ]) + score_labels [i ] + class_labels [i ]
140
- for i in range (len (idx ))
141
- ]
142
- prefix = 'Prediction: '
143
- texts .append (prefix + ('\n ' + ' ' * len (prefix )).join (labels ))
178
+ _draw_gt (data_sample , classes , draw_gt , texts )
179
+
180
+ _draw_pred (data_sample , classes , draw_pred , draw_score , texts )
144
181
145
182
img_scale = get_adaptive_scale (image .shape [:2 ])
146
183
text_cfg = {
0 commit comments