3838 20 : [0 , 250 , 0 ]
3939}
4040
41- # Create the custom color map
42- custom_colormap = np .zeros ((256 , 1 , 3 ), dtype = np .uint8 )
41+ def create_custom_color_map (color_map ):
42+ # Create the custom color map
43+ custom_colormap = np .zeros ((256 , 1 , 3 ), dtype = np .uint8 )
4344
44- for i in range (256 ):
45- if i in color_map :
46- custom_colormap [i , 0 , :] = color_map [i ]
47- else :
48- # If the index is not defined in the color map, set it to black
49- custom_colormap [i , 0 , :] = [0 , 0 , 0 ]
50- custom_colormap = custom_colormap [...,::- 1 ]
45+ for i in range (256 ):
46+ if str (i ) in color_map :
47+ custom_colormap [i , 0 , :] = color_map [str (i )]
48+ else :
49+ # If the index is not defined in the color map, set it to black
50+ custom_colormap [i , 0 , :] = [0 , 0 , 0 ]
51+ custom_colormap = custom_colormap [...,::- 1 ]
52+ return np .uint8 (custom_colormap )
5153
5254class AttentionModule (nn .Module ):
5355 def __init__ (self , in_channels , out_channels ):
@@ -329,6 +331,8 @@ def forward(self, x, meta_channel):
329331 x_semantics = self .decoder_semantic (x ) + 1 # offset of 1 to shift elu to ]0,inf[
330332
331333 return x_semantics
334+
335+
332336
333337
334338def build_normal_xyz (xyz , norm_factor = 0.25 , device = 'cuda' ):
@@ -407,7 +411,17 @@ def __init__(self):
407411 with open ("/home/appuser/data/config.json" ) as json_data :
408412 config = json .load (json_data )
409413 json_data .close ()
414+
415+ with open (config ["MODEL_CONFIG" ]) as json_data :
416+ model_config = json .load (json_data )
417+ json_data .close ()
418+
419+ # create color map
420+ self .color_map = create_custom_color_map (model_config ["CLASS_COLORS" ])
421+ print (self .color_map )
410422
423+ # create flag to use normals
424+ self .use_normals = model_config ["USE_NORMALS" ]
411425
412426 self .metadata_path = config ["METADATA_PATH" ]
413427
@@ -425,7 +439,10 @@ def __init__(self):
425439 self .metadata = client .SensorInfo (f .read ())
426440
427441 self .device = torch .device ("cuda" ) # if torch.cuda.is_available() else "cpu")
428- self .nocs_model = SemanticNetworkWithFPN (backbone = 'resnet34' , meta_channel_dim = 6 , num_classes = 20 , attention = True , multi_scale_meta = True )
442+ if model_config ["USE_NORMALS" ]:
443+ self .nocs_model = SemanticNetworkWithFPN (backbone = model_config ["BACKBONE" ], meta_channel_dim = 6 , num_classes = model_config ["NUM_CLASSES" ], attention = model_config ["USE_ATTENTION" ], multi_scale_meta = model_config ["USE_MULTI_SCALE" ])
444+ else :
445+ self .nocs_model = SemanticNetworkWithFPN (backbone = model_config ["BACKBONE" ], meta_channel_dim = 3 , num_classes = model_config ["NUM_CLASSES" ], attention = model_config ["USE_ATTENTION" ], multi_scale_meta = model_config ["USE_MULTI_SCALE" ])
429446 self .nocs_model .load_state_dict (torch .load (config ["MODEL_PATH" ], map_location = self .device ))
430447
431448 # Training loop
@@ -501,7 +518,10 @@ def run(self):
501518 normals = build_normal_xyz (xyz_ , norm_factor = 0.25 , device = 'cuda' )
502519 range_img = torch .norm (xyz_ , dim = 1 , keepdim = True ) #+ 1e-10
503520
504- outputs_semantic = self .nocs_model (torch .cat ([range_img , reflectivity ],axis = 1 ), torch .cat ([xyz_ , normals ],axis = 1 ))
521+ if self .use_normals :
522+ outputs_semantic = self .nocs_model (torch .cat ([range_img , reflectivity ],axis = 1 ), torch .cat ([xyz_ , normals ],axis = 1 ))
523+ else :
524+ outputs_semantic = self .nocs_model (torch .cat ([range_img , reflectivity ],axis = 1 ), xyz_ )
505525
506526 semseg_img = torch .argmax (outputs_semantic ,dim = 1 )
507527
@@ -514,7 +534,7 @@ def run(self):
514534 start_time_vis = self .get_clock ().now ()
515535 semantics_pred = (semseg_img ).permute (0 , 1 , 2 )[0 ,...].cpu ().detach ().numpy ()
516536 idx_VRUs = np .where (semantics_pred == 6 )
517- prev_sem_pred = cv2 .applyColorMap (np .uint8 (semantics_pred ), custom_colormap )
537+ prev_sem_pred = cv2 .applyColorMap (np .uint8 (semantics_pred ), self . color_map )
518538 end_time_vis = self .get_clock ().now ()
519539 self .get_logger ().info ('Cycle Time Vis: {} {}' .format (end_time_vis - start_time_vis , end_time_vis - start_time ))
520540
0 commit comments