@@ -333,11 +333,14 @@ class EvaluatorGeom:
333
333
334
334
return loss_trans_sum , loss_rot_sum
335
335
336
- def evaluate (self , vis = False , verbose = False ):
336
+ def evaluate (self , vis = False , verbose = False , random_order = False ):
337
337
for path in self .data_paths :
338
338
print (f'Evaluation on { os .path .basename (path )} ...' )
339
339
ds = self .DataClass (path , is_train = False , data_aug_conf = self .data_aug_conf , dphys_cfg = self .dphys_cfg )
340
- for i in tqdm (range (len (ds ))):
340
+ sample_range = np .arange (0 , len (ds ))
341
+ if random_order :
342
+ np .random .shuffle (sample_range )
343
+ for i in tqdm (sample_range ):
341
344
states_true , height = self .get_data (i , ds )
342
345
trans_diff , rot_diff = self .eval_diff_physics (height , states_true , vis = vis )
343
346
if rot_diff is not None :
@@ -374,7 +377,8 @@ class EvaluatorGeom:
374
377
class EvaluatorLSS (EvaluatorGeom ):
375
378
def __init__ (self , dphys_config_path , terrain_encoder_config_path , dataset ,
376
379
# model_path='../config/tb_runs/lss_2024_03_04_09_42_47/train_lss.pt'):
377
- model_path = '../config/tb_runs/lss_rellis3d_robingas_dphysics_terrain/train_lss.pt' ):
380
+ # model_path='../config/tb_runs/lss_rellis3d_robingas_dphysics_terrain/train_lss.pt'):
381
+ model_path = '../config/tb_runs/lss_rellis3d_2024_03_06_16_07_52/train_lss.pt' ):
378
382
super ().__init__ (dphys_config_path , terrain_encoder_config_path , dataset , model_path )
379
383
380
384
def load_model (self ):
@@ -576,8 +580,8 @@ def evaluate(dataset):
576
580
577
581
578
582
def main ():
579
- dataset = 'robingas'
580
- # dataset = 'rellis3d'
583
+ # dataset = 'robingas'
584
+ dataset = 'rellis3d'
581
585
582
586
# vis_data(dataset)
583
587
evaluate (dataset )
0 commit comments