16
16
from monoforce .datasets .rough import ROUGH , rough_seq_paths
17
17
from monoforce .models .terrain_encoder .utils import ego_to_cam , get_only_in_img_mask , denormalize_img
18
18
from monoforce .utils import read_yaml , write_to_csv , append_to_csv
19
+ from monoforce .losses import physics_loss , hm_loss
19
20
import matplotlib as mpl
20
21
21
22
@@ -80,45 +81,6 @@ def __init__(self,
80
81
# self.ds = Fusion(path=self.path, lss_cfg=self.lss_config, dphys_cfg=self.dphys_cfg, is_train=False)
81
82
self .loader = torch .utils .data .DataLoader (self .ds , batch_size = 1 , shuffle = False )
82
83
83
- def hm_loss (self , height_pred , height_gt , weights = None ):
84
- assert height_pred .shape == height_gt .shape , 'Height prediction and ground truth must have the same shape'
85
- if weights is None :
86
- weights = torch .ones_like (height_gt )
87
- assert weights .shape == height_gt .shape , 'Weights and height ground truth must have the same shape'
88
-
89
- # remove nan values
90
- mask_valid = ~ torch .isnan (height_gt )
91
- height_gt = height_gt [mask_valid ]
92
- height_pred = height_pred [mask_valid ]
93
- weights = weights [mask_valid ]
94
-
95
- # compute weighted loss
96
- loss = torch .nn .functional .mse_loss (height_pred * weights , height_gt * weights , reduction = 'mean' )
97
- assert not torch .isnan (loss ), 'Terrain Loss is nan'
98
-
99
- return loss
100
-
101
- def physics_loss (self , states_pred , states_gt , pred_ts , gt_ts ):
102
- # unpack the states
103
- X , Xd , R , Omega = states_gt
104
- X_pred , Xd_pred , R_pred , Omega_pred = states_pred
105
-
106
- # find the closest timesteps in the trajectory to the ground truth timesteps
107
- ts_ids = torch .argmin (torch .abs (pred_ts .unsqueeze (1 ) - gt_ts .unsqueeze (2 )), dim = 2 )
108
-
109
- # get the predicted states at the closest timesteps to the ground truth timesteps
110
- batch_size = X .shape [0 ]
111
- X_pred_gt_ts = X_pred [torch .arange (batch_size ).unsqueeze (1 ), ts_ids ]
112
-
113
- # remove nan values
114
- mask_valid = ~ torch .isnan (X_pred_gt_ts )
115
- X_pred_gt_ts = X_pred_gt_ts [mask_valid ]
116
- X = X [mask_valid ]
117
- loss = torch .nn .functional .mse_loss (X_pred_gt_ts , X )
118
- assert not torch .isnan (loss ), 'Physics Loss is nan'
119
-
120
- return loss
121
-
122
84
def run (self , vis = False , save = False ):
123
85
if save :
124
86
# create output folder
@@ -168,12 +130,12 @@ def run(self, vis=False, save=False):
168
130
# friction_pred = torch.ones_like(terrain_pred)
169
131
170
132
# evaluation losses
171
- terrain_loss = self . hm_loss (height_pred = terrain_pred [0 , 0 ], height_gt = hm_terrain [0 , 0 ])
133
+ loss_terrain = hm_loss (height_pred = terrain_pred [0 , 0 ], height_gt = hm_terrain [0 , 0 ], weights = hm_terrain [ 0 , 1 ])
172
134
states_gt = [Xs , Xds , Rs , Omegas ]
173
135
state0 = tuple ([s [:, 0 ] for s in states_gt ])
174
136
states_pred , _ = self .dphysics (z_grid = terrain_pred .squeeze (1 ), state = state0 ,
175
137
controls = controls , friction = friction_pred .squeeze (1 ))
176
- physics_loss = self . physics_loss (states_pred , states_gt , pred_ts = control_ts , gt_ts = traj_ts )
138
+ loss_physics = physics_loss (states_pred = states_pred , states_gt = states_gt , pred_ts = control_ts , gt_ts = traj_ts )
177
139
178
140
# visualizations
179
141
terrain_pred = terrain_pred [0 , 0 ].cpu ()
@@ -187,7 +149,7 @@ def run(self, vis=False, save=False):
187
149
# hm_points = hm_points[:, terrain_mask]
188
150
189
151
plt .clf ()
190
- plt .suptitle (f'Terrain Loss: { terrain_loss .item ():.4f} , Physics Loss: { physics_loss .item ():.4f} ' )
152
+ plt .suptitle (f'Terrain Loss: { loss_terrain .item ():.4f} , Physics Loss: { loss_physics .item ():.4f} ' )
191
153
for imgi , img in enumerate (imgs [0 ]):
192
154
cam_pts = ego_to_cam (hm_points , rots [0 , imgi ], trans [0 , imgi ], intrins [0 , imgi ])
193
155
mask = get_only_in_img_mask (cam_pts , H , W )
@@ -251,7 +213,7 @@ def run(self, vis=False, save=False):
251
213
if save :
252
214
plt .savefig (f'{ self .output_folder } /{ i :04d} .png' )
253
215
append_to_csv (f'{ self .output_folder } /losses.csv' ,
254
- f'{ i :04d} .png, { terrain_loss .item ():.4f} ,{ physics_loss .item ():.4f} \n ' )
216
+ f'{ i :04d} .png, { loss_terrain .item ():.4f} ,{ loss_physics .item ():.4f} \n ' )
255
217
256
218
plt .close (fig )
257
219
0 commit comments