@@ -118,6 +118,11 @@ class Trainer:
118
118
train_datasets .append (train_ds )
119
119
val_datasets .append (val_ds )
120
120
121
+ # concatenate datasets if map consistency is disabled
122
+ if not self .map_consistency :
123
+ train_datasets = [ConcatDataset (train_datasets )]
124
+ val_datasets = [ConcatDataset (val_datasets )]
125
+
121
126
# create dataloaders
122
127
train_loaders = []
123
128
for ds in train_datasets :
@@ -280,15 +285,21 @@ class Trainer:
280
285
grid_conf = self .grid_conf
281
286
device = self .cfg .device
282
287
283
- fig = plt .figure (figsize = (20 , 5 ))
284
- ax1 = fig .add_subplot (241 )
285
- ax2 = fig .add_subplot (242 , projection = '3d' )
286
- ax3 = fig .add_subplot (243 , projection = '3d' )
287
- ax4 = fig .add_subplot (244 , projection = '3d' )
288
- ax5 = fig .add_subplot (245 )
289
- ax6 = fig .add_subplot (246 )
290
- ax7 = fig .add_subplot (247 )
291
- ax8 = fig .add_subplot (248 )
288
+ fig = plt .figure (figsize = (20 , 12 ))
289
+ ax1 = fig .add_subplot (341 )
290
+ ax2 = fig .add_subplot (342 , projection = '3d' )
291
+ ax3 = fig .add_subplot (343 , projection = '3d' )
292
+ ax4 = fig .add_subplot (344 , projection = '3d' )
293
+ ax5 = fig .add_subplot (345 )
294
+ ax6 = fig .add_subplot (346 )
295
+ ax7 = fig .add_subplot (347 )
296
+ ax8 = fig .add_subplot (348 )
297
+ ax9 = fig .add_subplot (349 )
298
+ ax10 = fig .add_subplot (3 , 4 , 10 )
299
+ ax11 = fig .add_subplot (3 , 4 , 11 )
300
+ ax12 = fig .add_subplot (3 , 4 , 12 )
301
+
302
+ axes = [ax1 , ax2 , ax3 , ax4 , ax5 , ax6 , ax7 , ax8 , ax9 , ax10 , ax11 , ax12 ]
292
303
293
304
# visualize training predictions
294
305
with torch .no_grad ():
@@ -300,7 +311,7 @@ class Trainer:
300
311
height_pred_geom , height_pred_diff = model .bevencode (voxel_feats )
301
312
height_pred = height_pred_geom - height_pred_diff
302
313
303
- for ax in [ ax1 , ax2 , ax3 , ax4 , ax5 , ax6 , ax7 , ax8 ] :
314
+ for ax in axes :
304
315
ax .clear ()
305
316
306
317
# plot image
@@ -309,14 +320,16 @@ class Trainer:
309
320
310
321
# plot prediction as surface
311
322
ax2 .set_title ('Pred Surface' )
312
- height = height_pred [0 ][0 ].cpu ().numpy ().T
323
+ height_pred = height_pred [0 ][0 ].cpu ().numpy ().T
324
+ height_pred_geom = height_pred_geom [0 ][0 ].cpu ().numpy ().T
325
+ height_pred_diff = height_pred_diff [0 ][0 ].cpu ().numpy ().T
313
326
height_lidar = hm_lidar [0 ][0 ].cpu ().numpy ().T
314
327
height_traj = hm_traj [0 ][0 ].cpu ().numpy ().T
315
328
mask = hm_lidar [0 ][1 ].bool ().cpu ().numpy ().T
316
329
x_grid = np .arange (grid_conf ['xbound' ][0 ], grid_conf ['xbound' ][1 ], grid_conf ['xbound' ][2 ])
317
330
y_grid = np .arange (grid_conf ['ybound' ][0 ], grid_conf ['ybound' ][1 ], grid_conf ['ybound' ][2 ])
318
331
x_grid , y_grid = np .meshgrid (x_grid , y_grid )
319
- ax2 .plot_surface (x_grid , y_grid , height , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
332
+ ax2 .plot_surface (x_grid , y_grid , height_pred , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
320
333
ax2 .set_zlim (- 1.0 , 1.0 )
321
334
ax2 .set_xlabel ('x [m]' )
322
335
ax2 .set_ylabel ('y [m]' )
@@ -336,19 +349,39 @@ class Trainer:
336
349
337
350
# plot prediction as image
338
351
ax5 .set_title ('Prediction' )
339
- ax5 .imshow (height , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
352
+ ax5 .imshow (height_pred , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
340
353
341
354
ax6 .set_title ('Masked Prediction' )
342
- height_vis = np .zeros_like (height )
343
- height_vis [mask ] = height [mask ]
344
- ax6 .imshow (height_vis , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
355
+ height_pred_vis = np .zeros_like (height_pred )
356
+ height_pred_vis [mask ] = height_pred [mask ]
357
+ ax6 .imshow (height_pred_vis , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
345
358
346
359
ax7 .set_title ('Lidar' )
347
360
ax7 .imshow (height_lidar , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
348
361
349
362
ax8 .set_title ('Traj' )
350
363
ax8 .imshow (height_traj , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
351
364
365
+ # predicted geometric height
366
+ ax9 .set_title ('Geom Pred' )
367
+ ax9 .imshow (height_pred_geom , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
368
+
369
+ # masked predicted geometric height
370
+ ax10 .set_title ('Masked Geom Pred' )
371
+ height_pred_geom_vis = np .zeros_like (height_pred )
372
+ height_pred_geom_vis [mask ] = height_pred_geom [mask ]
373
+ ax10 .imshow (height_pred_geom_vis , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
374
+
375
+ # predicted diff height
376
+ ax11 .set_title ('Diff Pred' )
377
+ ax11 .imshow (height_pred_diff , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
378
+
379
+ # masked predicted diff height
380
+ ax12 .set_title ('Masked Diff Pred' )
381
+ height_pred_diff_vis = np .zeros_like (height_pred )
382
+ height_pred_diff_vis [mask ] = height_pred_diff [mask ]
383
+ ax12 .imshow (height_pred_diff_vis , origin = 'lower' , cmap = 'jet' , vmin = - 1.0 , vmax = 1.0 )
384
+
352
385
return fig
353
386
354
387
0 commit comments