@@ -36,14 +36,20 @@ def process_model_config(
36
36
cfg = copy .deepcopy (model_cfg )
37
37
test_pipeline = cfg .test_dataloader .dataset .pipeline
38
38
data_preprocessor = cfg .model .data_preprocessor
39
- codec = cfg .codec
40
- if isinstance (codec , list ):
41
- codec = codec [- 1 ]
42
- input_size = codec ['input_size' ] if input_shape is None else input_shape
39
+ codec = getattr (cfg , 'codec' , None )
40
+ if codec is not None :
41
+ if isinstance (codec , list ):
42
+ codec = codec [- 1 ]
43
+ input_size = codec ['input_size' ] if input_shape is None \
44
+ else input_shape
45
+ else :
46
+ input_size = cfg .img_scale
47
+
43
48
test_pipeline [0 ] = dict (type = 'LoadImageFromFile' )
44
49
for i in reversed (range (len (test_pipeline ))):
45
50
trans = test_pipeline [i ]
46
- if trans ['type' ] == 'PackPoseInputs' :
51
+ if trans ['type' ] == 'PackPoseInputs' or trans [
52
+ 'type' ] == 'PackDetPoseInputs' :
47
53
test_pipeline .pop (i )
48
54
elif trans ['type' ] == 'GetBBoxCenterScale' :
49
55
trans ['type' ] = 'TopDownGetBboxCenterScale'
@@ -53,22 +59,37 @@ def process_model_config(
53
59
trans ['type' ] = 'TopDownAffine'
54
60
trans ['image_size' ] = input_size
55
61
trans .pop ('input_size' )
56
-
57
- test_pipeline .append (
58
- dict (
59
- type = 'Normalize' ,
60
- mean = data_preprocessor .mean ,
61
- std = data_preprocessor .std ,
62
- to_rgb = data_preprocessor .bgr_to_rgb ))
62
+ elif trans ['type' ][:6 ] == 'mmdet.' :
63
+ trans ['type' ] = trans ['type' ][6 :]
64
+
65
+ # DetDataPreprocessor does not have mean, std, bgr_to_rgb
66
+ # TODO: implement PoseToDetConverter and PackDetPoseInputs in C++
67
+ if data_preprocessor .type != 'mmdet.DetDataPreprocessor' :
68
+ test_pipeline .append (
69
+ dict (
70
+ type = 'Normalize' ,
71
+ mean = data_preprocessor .mean ,
72
+ std = data_preprocessor .std ,
73
+ to_rgb = data_preprocessor .bgr_to_rgb ))
63
74
test_pipeline .append (dict (type = 'ImageToTensor' , keys = ['img' ]))
64
- test_pipeline .append (
65
- dict (
66
- type = 'Collect' ,
67
- keys = ['img' ],
68
- meta_keys = [
69
- 'img_shape' , 'pad_shape' , 'ori_shape' , 'img_norm_cfg' ,
70
- 'scale_factor' , 'bbox_score' , 'center' , 'scale'
71
- ]))
75
+ if data_preprocessor .type != 'mmdet.DetDataPreprocessor' :
76
+ test_pipeline .append (
77
+ dict (
78
+ type = 'Collect' ,
79
+ keys = ['img' ],
80
+ meta_keys = [
81
+ 'img_shape' , 'pad_shape' , 'ori_shape' , 'img_norm_cfg' ,
82
+ 'scale_factor' , 'bbox_score' , 'center' , 'scale'
83
+ ]))
84
+ else :
85
+ test_pipeline .append (
86
+ dict (
87
+ type = 'Collect' ,
88
+ keys = ['img' ],
89
+ meta_keys = [
90
+ 'id' , 'img_id' , 'img_path' , 'ori_shape' , 'img_shape' ,
91
+ 'scale_factor' , 'flip_indices'
92
+ ]))
72
93
73
94
cfg .test_dataloader .dataset .pipeline = test_pipeline
74
95
return cfg
@@ -345,13 +366,19 @@ def get_preprocess(self, *args, **kwargs) -> Dict:
345
366
346
367
def get_postprocess (self , * args , ** kwargs ) -> Dict :
347
368
"""Get the postprocess information for SDK."""
348
- codec = self .model_cfg .codec
349
- if isinstance (codec , (list , tuple )):
350
- codec = codec [- 1 ]
351
- component = 'UNKNOWN'
369
+ codec = getattr (self .model_cfg , 'codec' , None )
352
370
params = copy .deepcopy (self .model_cfg .model .test_cfg )
353
- params .update (codec )
354
- if self .model_cfg .model .type == 'TopdownPoseEstimator' :
371
+ component = 'UNKNOWN'
372
+ if codec is not None :
373
+ if isinstance (codec , (list , tuple )):
374
+ codec = codec [- 1 ]
375
+ params .update (codec )
376
+ else :
377
+ # TODO: implement the YOLOXPoseHeadDecode component in C++
378
+ component = 'YOLOXPoseHeadDecode'
379
+
380
+ if self .model_cfg .model .type == 'TopdownPoseEstimator' \
381
+ and codec is not None :
355
382
component = 'TopdownHeatmapSimpleHeadDecode'
356
383
if codec .type == 'MSRAHeatmap' :
357
384
params ['post_process' ] = 'default'
0 commit comments