@@ -159,18 +159,14 @@ def __call__(self, data_B, data_F, data_T):
159159 if data_B .ndim == 1 :
160160 data_B = np .array (data_B ).reshape (1 , - 1 )
161161
162- loader = get_dataloader (data_B , data_F , data_T , self .mdl .norm )
162+ _ , ts_feats , scalar_feats = get_dataloader (data_B , data_F , data_T , self .mdl .norm )
163163
164164 # 2.Validate the models
165- data_P = torch .Tensor ([]).to (self .device ) # Allocate memory to store loss density
166-
167- with torch .no_grad ():
165+ self .mdl .eval ()
166+ with torch .inference_mode ():
168167 # Start model evaluation explicitly
169- self .mdl .eval ()
170- for inputs , vars in loader :
171- Pv , h_series = self .mdl (inputs .to (self .device ), vars .to (self .device ))
168+ data_P , h_series = self .mdl (ts_feats .to (self .device ), scalar_feats .to (self .device ))
172169
173- data_P = torch .cat ((data_P , Pv .to (self .device )), dim = 0 )
174170 data_P , h_series = data_P .cpu ().numpy (), h_series .cpu ().numpy ()
175171
176172 # 3.Return results
@@ -337,9 +333,7 @@ def forward(self, x, hidden=None):
337333
338334
339335def get_dataloader (data_B , data_F , data_T , norm , n_init = 32 ):
340- """
341- Preprocess data into a data loader.
342-
336+ """Preprocess data into a data loader.
343337 Get a test dataloader.
344338
345339 Parameters
@@ -394,14 +388,14 @@ def get_dataloader(data_B, data_F, data_T, norm, n_init=32):
394388
395389 s0 = get_operator_init (in_B [:, 0 , 0 ] - in_dB [:, 0 , 0 ], in_dB , max_B , min_B ) # Operator inital state
396390
391+ ts_feats = torch .cat ((in_B , in_dB , in_dB_dt ), dim = 2 )
392+ scalar_feats = torch .cat ((in_F , in_T , s0 ), dim = 1 )
397393 # 6. Create dataloader to speed up data processing
398- test_dataset = torch .utils .data .TensorDataset (
399- torch .cat ((in_B , in_dB , in_dB_dt ), dim = 2 ), torch .cat ((in_F , in_T , s0 ), dim = 1 )
400- )
394+ test_dataset = torch .utils .data .TensorDataset (ts_feats , scalar_feats )
401395 kwargs = {"num_workers" : 0 , "batch_size" : 128 , "drop_last" : False }
402396 test_loader = torch .utils .data .DataLoader (test_dataset , ** kwargs )
403397
404- return test_loader
398+ return test_loader , ts_feats , scalar_feats
405399
406400
407401# %% Predict the operator state at t0
0 commit comments