Skip to content

Commit 5d70f30

Browse files
committed
concat datasets if map consistency is disabled, more detailed tb logging
1 parent 1123aa0 commit 5d70f30

File tree

1 file changed

+49
-16
lines changed

1 file changed

+49
-16
lines changed

scripts/train_lss

+49-16
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ class Trainer:
118118
train_datasets.append(train_ds)
119119
val_datasets.append(val_ds)
120120

121+
# concatenate datasets if map consistency is disabled
122+
if not self.map_consistency:
123+
train_datasets = [ConcatDataset(train_datasets)]
124+
val_datasets = [ConcatDataset(val_datasets)]
125+
121126
# create dataloaders
122127
train_loaders = []
123128
for ds in train_datasets:
@@ -280,15 +285,21 @@ class Trainer:
280285
grid_conf = self.grid_conf
281286
device = self.cfg.device
282287

283-
fig = plt.figure(figsize=(20, 5))
284-
ax1 = fig.add_subplot(241)
285-
ax2 = fig.add_subplot(242, projection='3d')
286-
ax3 = fig.add_subplot(243, projection='3d')
287-
ax4 = fig.add_subplot(244, projection='3d')
288-
ax5 = fig.add_subplot(245)
289-
ax6 = fig.add_subplot(246)
290-
ax7 = fig.add_subplot(247)
291-
ax8 = fig.add_subplot(248)
288+
fig = plt.figure(figsize=(20, 12))
289+
ax1 = fig.add_subplot(341)
290+
ax2 = fig.add_subplot(342, projection='3d')
291+
ax3 = fig.add_subplot(343, projection='3d')
292+
ax4 = fig.add_subplot(344, projection='3d')
293+
ax5 = fig.add_subplot(345)
294+
ax6 = fig.add_subplot(346)
295+
ax7 = fig.add_subplot(347)
296+
ax8 = fig.add_subplot(348)
297+
ax9 = fig.add_subplot(349)
298+
ax10 = fig.add_subplot(3, 4, 10)
299+
ax11 = fig.add_subplot(3, 4, 11)
300+
ax12 = fig.add_subplot(3, 4, 12)
301+
302+
axes = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9, ax10, ax11, ax12]
292303

293304
# visualize training predictions
294305
with torch.no_grad():
@@ -300,7 +311,7 @@ class Trainer:
300311
height_pred_geom, height_pred_diff = model.bevencode(voxel_feats)
301312
height_pred = height_pred_geom - height_pred_diff
302313

303-
for ax in [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8]:
314+
for ax in axes:
304315
ax.clear()
305316

306317
# plot image
@@ -309,14 +320,16 @@ class Trainer:
309320

310321
# plot prediction as surface
311322
ax2.set_title('Pred Surface')
312-
height = height_pred[0][0].cpu().numpy().T
323+
height_pred = height_pred[0][0].cpu().numpy().T
324+
height_pred_geom = height_pred_geom[0][0].cpu().numpy().T
325+
height_pred_diff = height_pred_diff[0][0].cpu().numpy().T
313326
height_lidar = hm_lidar[0][0].cpu().numpy().T
314327
height_traj = hm_traj[0][0].cpu().numpy().T
315328
mask = hm_lidar[0][1].bool().cpu().numpy().T
316329
x_grid = np.arange(grid_conf['xbound'][0], grid_conf['xbound'][1], grid_conf['xbound'][2])
317330
y_grid = np.arange(grid_conf['ybound'][0], grid_conf['ybound'][1], grid_conf['ybound'][2])
318331
x_grid, y_grid = np.meshgrid(x_grid, y_grid)
319-
ax2.plot_surface(x_grid, y_grid, height, cmap='jet', vmin=-1.0, vmax=1.0)
332+
ax2.plot_surface(x_grid, y_grid, height_pred, cmap='jet', vmin=-1.0, vmax=1.0)
320333
ax2.set_zlim(-1.0, 1.0)
321334
ax2.set_xlabel('x [m]')
322335
ax2.set_ylabel('y [m]')
@@ -336,19 +349,39 @@ class Trainer:
336349

337350
# plot prediction as image
338351
ax5.set_title('Prediction')
339-
ax5.imshow(height, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
352+
ax5.imshow(height_pred, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
340353

341354
ax6.set_title('Masked Prediction')
342-
height_vis = np.zeros_like(height)
343-
height_vis[mask] = height[mask]
344-
ax6.imshow(height_vis, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
355+
height_pred_vis = np.zeros_like(height_pred)
356+
height_pred_vis[mask] = height_pred[mask]
357+
ax6.imshow(height_pred_vis, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
345358

346359
ax7.set_title('Lidar')
347360
ax7.imshow(height_lidar, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
348361

349362
ax8.set_title('Traj')
350363
ax8.imshow(height_traj, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
351364

365+
# predicted geometric height
366+
ax9.set_title('Geom Pred')
367+
ax9.imshow(height_pred_geom, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
368+
369+
# masked predicted geometric height
370+
ax10.set_title('Masked Geom Pred')
371+
height_pred_geom_vis = np.zeros_like(height_pred)
372+
height_pred_geom_vis[mask] = height_pred_geom[mask]
373+
ax10.imshow(height_pred_geom_vis, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
374+
375+
# predicted diff height
376+
ax11.set_title('Diff Pred')
377+
ax11.imshow(height_pred_diff, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
378+
379+
# masked predicted diff height
380+
ax12.set_title('Masked Diff Pred')
381+
height_pred_diff_vis = np.zeros_like(height_pred)
382+
height_pred_diff_vis[mask] = height_pred_diff[mask]
383+
ax12.imshow(height_pred_diff_vis, origin='lower', cmap='jet', vmin=-1.0, vmax=1.0)
384+
352385
return fig
353386

354387

0 commit comments

Comments
 (0)