Skip to content

Commit 58300ca

Browse files
committed
lss train: height map difference penalty
1 parent fb72298 commit 58300ca

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

Diff for: scripts/train_lss

+7-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ from monoforce.datasets.data import TravData, explore_data
1414
from monoforce.config import Config
1515
from monoforce.utils import read_yaml, write_to_yaml, normalize
1616
from monoforce.datasets import seq_paths, sim_seq_paths
17-
from monoforce.losses import RMSE
17+
from monoforce.losses import RMSE, total_variation
1818
from tqdm import tqdm
1919
from torch.utils.tensorboard import SummaryWriter
2020
from datetime import datetime
@@ -222,7 +222,11 @@ class Trainer:
222222
loss_geom = self.lidar_hm_loss(height_pred_geom, height_lidar, weights_lidar)
223223
loss_rigid = self.traj_hm_loss(height_pred_rigid, height_traj, weights_traj)
224224

225-
loss = loss_geom + 100.*loss_rigid
225+
# add height difference loss
226+
# loss_hdiff = height_pred_diff.abs().mean()
227+
loss_hdiff = total_variation(height_pred_diff)
228+
229+
loss = loss_geom + 100.*loss_rigid + 0.1*loss_hdiff
226230
if self.map_consistency and len(height_pred_rigid) > 1:
227231
loss_map = self.map_consistency_loss(height_pred_rigid, map_pose)
228232
loss += 0.1*loss_map
@@ -239,6 +243,7 @@ class Trainer:
239243
counter += 1
240244
self.writer.add_scalar(f"{'train' if train else 'val'}/iter_loss_geom", loss_geom, counter)
241245
self.writer.add_scalar(f"{'train' if train else 'val'}/iter_loss_rigid", loss_rigid, counter)
246+
self.writer.add_scalar(f"{'train' if train else 'val'}/iter_loss_hdiff", loss_hdiff, counter)
242247
self.writer.add_scalar(f"{'train' if train else 'val'}/iter_loss_map", loss_map, counter)
243248
self.writer.add_scalar(f"{'train' if train else 'val'}/iter_loss", loss, counter)
244249

0 commit comments

Comments
 (0)