77from time import time
88
99import postprocessing_data as pp
10- from pathlib import Path
1110from executorch .runtime import Verification , Runtime
1211from inference_tools .loop_tools import loop_inference , get_exec_time
1312from io_adapter import IOAdapter
@@ -158,21 +157,21 @@ def load_model(model_path):
158157 model_path ,
159158 verification = Verification .Minimal ,
160159 )
161- return program .load_method (" forward" )
160+ return program .load_method (' forward' )
162161
163162
164163def inference_executorch (net , num_iterations , get_slice , input_name , test_duration ):
165164 predictions = None
166165 time_infer = []
167166 if num_iterations == 1 :
168- t0 = time ()
169- slice_input = get_slice ()
170- predictions = net .execute ((slice_input [input_name ],))
171- t1 = time ()
172- time_infer .append (t1 - t0 )
167+ t0 = time ()
168+ slice_input = get_slice ()
169+ predictions = net .execute ((slice_input [input_name ],))
170+ t1 = time ()
171+ time_infer .append (t1 - t0 )
173172 else :
174- loop_results = loop_inference (num_iterations , test_duration )(inference_iteration )(get_slice , input_name , net )
175- time_infer = loop_results ['time_infer' ]
173+ loop_results = loop_inference (num_iterations , test_duration )(inference_iteration )(get_slice , input_name , net )
174+ time_infer = loop_results ['time_infer' ]
176175 return predictions , time_infer
177176
178177
@@ -190,12 +189,12 @@ def infer_slice(input_name, net, slice_input):
190189
191190def prepare_output (result , output_names , task_type ):
192191 if task_type in ['feedforward' ]:
193- return {}
192+ return {}
194193 elif task_type in ['classification' ]:
195- log .info ('Converting output tensor to print results' )
196- return {output_names [0 ]: result [0 ].numpy ()}
194+ log .info ('Converting output tensor to print results' )
195+ return {output_names [0 ]: result [0 ].numpy ()}
197196 else :
198- raise ValueError (f'Unsupported task { task_type } to print inference results' )
197+ raise ValueError (f'Unsupported task { task_type } to print inference results' )
199198
200199
201200def main ():
@@ -241,4 +240,4 @@ def main():
241240
242241
243242if __name__ == '__main__' :
244- sys .exit (main () or 0 )
243+ sys .exit (main () or 0 )
0 commit comments