@@ -14,7 +14,7 @@ from monoforce.datasets.data import TravData, explore_data
14
14
from monoforce .config import Config
15
15
from monoforce .utils import read_yaml , write_to_yaml , normalize
16
16
from monoforce .datasets import seq_paths , sim_seq_paths
17
- from monoforce .losses import RMSE
17
+ from monoforce .losses import RMSE , total_variation
18
18
from tqdm import tqdm
19
19
from torch .utils .tensorboard import SummaryWriter
20
20
from datetime import datetime
@@ -222,7 +222,11 @@ class Trainer:
222
222
loss_geom = self .lidar_hm_loss (height_pred_geom , height_lidar , weights_lidar )
223
223
loss_rigid = self .traj_hm_loss (height_pred_rigid , height_traj , weights_traj )
224
224
225
- loss = loss_geom + 100. * loss_rigid
225
+ # add height difference loss
226
+ # loss_hdiff = height_pred_diff.abs().mean()
227
+ loss_hdiff = total_variation (height_pred_diff )
228
+
229
+ loss = loss_geom + 100. * loss_rigid + 0.1 * loss_hdiff
226
230
if self .map_consistency and len (height_pred_rigid ) > 1 :
227
231
loss_map = self .map_consistency_loss (height_pred_rigid , map_pose )
228
232
loss += 0.1 * loss_map
@@ -239,6 +243,7 @@ class Trainer:
239
243
counter += 1
240
244
self .writer .add_scalar (f"{ 'train' if train else 'val' } /iter_loss_geom" , loss_geom , counter )
241
245
self .writer .add_scalar (f"{ 'train' if train else 'val' } /iter_loss_rigid" , loss_rigid , counter )
246
+ self .writer .add_scalar (f"{ 'train' if train else 'val' } /iter_loss_hdiff" , loss_hdiff , counter )
242
247
self .writer .add_scalar (f"{ 'train' if train else 'val' } /iter_loss_map" , loss_map , counter )
243
248
self .writer .add_scalar (f"{ 'train' if train else 'val' } /iter_loss" , loss , counter )
244
249
0 commit comments