Skip to content

Commit 6005adb

Browse files
Update semantic_lidar_node.py
1 parent c58e402 commit 6005adb

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

ros2_ws/src/semantic_lidar_package/semantic_lidar_package/semantic_lidar_node.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,18 @@
3838
20: [0, 250, 0]
3939
}
4040

41-
# Create the custom color map
42-
custom_colormap = np.zeros((256, 1, 3), dtype=np.uint8)
41+
def create_custom_color_map(color_map):
42+
# Create the custom color map
43+
custom_colormap = np.zeros((256, 1, 3), dtype=np.uint8)
4344

44-
for i in range(256):
45-
if i in color_map:
46-
custom_colormap[i, 0, :] = color_map[i]
47-
else:
48-
# If the index is not defined in the color map, set it to black
49-
custom_colormap[i, 0, :] = [0, 0, 0]
50-
custom_colormap = custom_colormap[...,::-1]
45+
for i in range(256):
46+
if str(i) in color_map:
47+
custom_colormap[i, 0, :] = color_map[str(i)]
48+
else:
49+
# If the index is not defined in the color map, set it to black
50+
custom_colormap[i, 0, :] = [0, 0, 0]
51+
custom_colormap = custom_colormap[...,::-1]
52+
return np.uint8(custom_colormap)
5153

5254
class AttentionModule(nn.Module):
5355
def __init__(self, in_channels, out_channels):
@@ -329,6 +331,8 @@ def forward(self, x, meta_channel):
329331
x_semantics = self.decoder_semantic(x) + 1 # offset of 1 to shift elu to ]0,inf[
330332

331333
return x_semantics
334+
335+
332336

333337

334338
def build_normal_xyz(xyz, norm_factor=0.25, device='cuda'):
@@ -407,7 +411,17 @@ def __init__(self):
407411
with open("/home/appuser/data/config.json") as json_data:
408412
config = json.load(json_data)
409413
json_data.close()
414+
415+
with open(config["MODEL_CONFIG"]) as json_data:
416+
model_config = json.load(json_data)
417+
json_data.close()
418+
419+
# create color map
420+
self.color_map = create_custom_color_map(model_config["CLASS_COLORS"])
421+
print(self.color_map)
410422

423+
# create flag to use normals
424+
self.use_normals = model_config["USE_NORMALS"]
411425

412426
self.metadata_path = config["METADATA_PATH"]
413427

@@ -425,7 +439,10 @@ def __init__(self):
425439
self.metadata = client.SensorInfo(f.read())
426440

427441
self.device = torch.device("cuda") # if torch.cuda.is_available() else "cpu")
428-
self.nocs_model = SemanticNetworkWithFPN(backbone='resnet34', meta_channel_dim=6, num_classes=20, attention=True, multi_scale_meta=True)
442+
if model_config["USE_NORMALS"]:
443+
self.nocs_model = SemanticNetworkWithFPN(backbone=model_config["BACKBONE"], meta_channel_dim=6, num_classes=model_config["NUM_CLASSES"], attention=model_config["USE_ATTENTION"], multi_scale_meta=model_config["USE_MULTI_SCALE"])
444+
else:
445+
self.nocs_model = SemanticNetworkWithFPN(backbone=model_config["BACKBONE"], meta_channel_dim=3, num_classes=model_config["NUM_CLASSES"], attention=model_config["USE_ATTENTION"], multi_scale_meta=model_config["USE_MULTI_SCALE"])
429446
self.nocs_model.load_state_dict(torch.load(config["MODEL_PATH"], map_location=self.device))
430447

431448
# Training loop
@@ -501,7 +518,10 @@ def run(self):
501518
normals = build_normal_xyz(xyz_, norm_factor=0.25, device='cuda')
502519
range_img = torch.norm(xyz_, dim=1, keepdim=True) #+ 1e-10
503520

504-
outputs_semantic = self.nocs_model(torch.cat([range_img, reflectivity],axis=1), torch.cat([xyz_, normals],axis=1))
521+
if self.use_normals:
522+
outputs_semantic = self.nocs_model(torch.cat([range_img, reflectivity],axis=1), torch.cat([xyz_, normals],axis=1))
523+
else:
524+
outputs_semantic = self.nocs_model(torch.cat([range_img, reflectivity],axis=1), xyz_)
505525

506526
semseg_img = torch.argmax(outputs_semantic,dim=1)
507527

@@ -514,7 +534,7 @@ def run(self):
514534
start_time_vis= self.get_clock().now()
515535
semantics_pred = (semseg_img).permute(0, 1, 2)[0,...].cpu().detach().numpy()
516536
idx_VRUs = np.where(semantics_pred==6)
517-
prev_sem_pred = cv2.applyColorMap(np.uint8(semantics_pred), custom_colormap)
537+
prev_sem_pred = cv2.applyColorMap(np.uint8(semantics_pred), self.color_map)
518538
end_time_vis= self.get_clock().now()
519539
self.get_logger().info('Cycle Time Vis: {} {}'.format(end_time_vis-start_time_vis, end_time_vis-start_time))
520540

0 commit comments

Comments
 (0)