11
11
from mmengine .visualization .utils import img_from_canvas
12
12
13
13
from mmpretrain .registry import VISUALIZERS
14
- from mmpretrain .structures import DataSample
14
+ from mmpretrain .structures import DataSample , MultiTaskDataSample
15
15
from .utils import create_figure , get_adaptive_scale
16
16
17
17
@@ -114,33 +114,9 @@ def visualize_cls(self,
114
114
texts = []
115
115
self .set_image (image )
116
116
117
- if draw_gt and 'gt_label' in data_sample :
118
- idx = data_sample .gt_label .tolist ()
119
- class_labels = ['' ] * len (idx )
120
- if classes is not None :
121
- class_labels = [f' ({ classes [i ]} )' for i in idx ]
122
- labels = [str (idx [i ]) + class_labels [i ] for i in range (len (idx ))]
123
- prefix = 'Ground truth: '
124
- texts .append (prefix + ('\n ' + ' ' * len (prefix )).join (labels ))
125
-
126
- if draw_pred and 'pred_label' in data_sample :
127
- idx = data_sample .pred_label .tolist ()
128
- score_labels = ['' ] * len (idx )
129
- class_labels = ['' ] * len (idx )
130
- if draw_score and 'pred_score' in data_sample :
131
- score_labels = [
132
- f', { data_sample .pred_score [i ].item ():.2f} ' for i in idx
133
- ]
134
-
135
- if classes is not None :
136
- class_labels = [f' ({ classes [i ]} )' for i in idx ]
117
+ self .draw_gt (data_sample , classes , draw_gt , texts )
137
118
138
- labels = [
139
- str (idx [i ]) + score_labels [i ] + class_labels [i ]
140
- for i in range (len (idx ))
141
- ]
142
- prefix = 'Prediction: '
143
- texts .append (prefix + ('\n ' + ' ' * len (prefix )).join (labels ))
119
+ self .draw_pred (data_sample , classes , draw_pred , draw_score , texts )
144
120
145
121
img_scale = get_adaptive_scale (image .shape [:2 ])
146
122
text_cfg = {
@@ -167,6 +143,65 @@ def visualize_cls(self,
167
143
168
144
return drawn_img
169
145
146
+ def draw_pred (self ,
147
+ data_sample : DataSample ,
148
+ classes : Optional [Sequence [str ]],
149
+ draw_pred : bool ,
150
+ draw_score : bool ,
151
+ texts : Sequence [str ],
152
+ parent_task : str = '' ):
153
+ if isinstance (data_sample , MultiTaskDataSample ):
154
+ for task in data_sample .tasks :
155
+ sub_task = f'{ parent_task } _{ task } ' if parent_task else task
156
+ self .draw_pred (
157
+ data_sample .get (task ), classes , draw_pred , draw_score ,
158
+ texts , sub_task )
159
+ else :
160
+ if draw_pred and 'pred_label' in data_sample :
161
+ idx = data_sample .pred_label .tolist ()
162
+ score_labels = ['' ] * len (idx )
163
+ class_labels = ['' ] * len (idx )
164
+ if draw_score and 'pred_score' in data_sample :
165
+ score_labels = [
166
+ f', { data_sample .pred_score [i ].item ():.2f} '
167
+ for i in idx
168
+ ]
169
+
170
+ if classes is not None :
171
+ class_labels = [f' ({ classes [i ]} )' for i in idx ]
172
+
173
+ labels = [
174
+ str (idx [i ]) + score_labels [i ] + class_labels [i ]
175
+ for i in range (len (idx ))
176
+ ]
177
+ prefix = f'{ parent_task } Prediction: ' if parent_task \
178
+ else 'Prediction: '
179
+ texts .append (prefix + ('\n ' + ' ' * len (prefix )).join (labels ))
180
+
181
+ def draw_gt (self ,
182
+ data_sample : DataSample ,
183
+ classes : Optional [Sequence [str ]],
184
+ draw_gt : bool ,
185
+ texts : Sequence [str ],
186
+ parent_task : str = '' ):
187
+ if isinstance (data_sample , MultiTaskDataSample ):
188
+ for task in data_sample .tasks :
189
+ sub_task = f'{ parent_task } _{ task } ' if parent_task else task
190
+ self .draw_gt (
191
+ data_sample .get (task ), classes , draw_gt , texts , sub_task )
192
+ else :
193
+ if draw_gt and 'gt_label' in data_sample :
194
+ idx = data_sample .gt_label .tolist ()
195
+ class_labels = ['' ] * len (idx )
196
+ if classes is not None :
197
+ class_labels = [f' ({ classes [i ]} )' for i in idx ]
198
+ labels = [
199
+ str (idx [i ]) + class_labels [i ] for i in range (len (idx ))
200
+ ]
201
+ prefix = f'{ parent_task } Ground truth: ' if parent_task \
202
+ else 'Ground truth: '
203
+ texts .append (prefix + ('\n ' + ' ' * len (prefix )).join (labels ))
204
+
170
205
@master_only
171
206
def visualize_image_retrieval (self ,
172
207
image : np .ndarray ,
0 commit comments