Skip to content

Commit 76b0de0

Browse files
committed
codestyle 1
1 parent c2b16cf commit 76b0de0

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

src/inference/inference_executorch.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from time import time
88

99
import postprocessing_data as pp
10-
from pathlib import Path
1110
from executorch.runtime import Verification, Runtime
1211
from inference_tools.loop_tools import loop_inference, get_exec_time
1312
from io_adapter import IOAdapter
@@ -158,21 +157,21 @@ def load_model(model_path):
158157
model_path,
159158
verification=Verification.Minimal,
160159
)
161-
return program.load_method("forward")
160+
return program.load_method('forward')
162161

163162

164163
def inference_executorch(net, num_iterations, get_slice, input_name, test_duration):
165164
predictions = None
166165
time_infer = []
167166
if num_iterations == 1:
168-
t0 = time()
169-
slice_input = get_slice()
170-
predictions = net.execute((slice_input[input_name],))
171-
t1 = time()
172-
time_infer.append(t1 - t0)
167+
t0 = time()
168+
slice_input = get_slice()
169+
predictions = net.execute((slice_input[input_name],))
170+
t1 = time()
171+
time_infer.append(t1 - t0)
173172
else:
174-
loop_results = loop_inference(num_iterations, test_duration)(inference_iteration)(get_slice, input_name, net)
175-
time_infer = loop_results['time_infer']
173+
loop_results = loop_inference(num_iterations, test_duration)(inference_iteration)(get_slice, input_name, net)
174+
time_infer = loop_results['time_infer']
176175
return predictions, time_infer
177176

178177

@@ -190,12 +189,12 @@ def infer_slice(input_name, net, slice_input):
190189

191190
def prepare_output(result, output_names, task_type):
192191
if task_type in ['feedforward']:
193-
return {}
192+
return {}
194193
elif task_type in ['classification']:
195-
log.info('Converting output tensor to print results')
196-
return {output_names[0]: result[0].numpy()}
194+
log.info('Converting output tensor to print results')
195+
return {output_names[0]: result[0].numpy()}
197196
else:
198-
raise ValueError(f'Unsupported task {task_type} to print inference results')
197+
raise ValueError(f'Unsupported task {task_type} to print inference results')
199198

200199

201200
def main():
@@ -241,4 +240,4 @@ def main():
241240

242241

243242
if __name__ == '__main__':
244-
sys.exit(main() or 0)
243+
sys.exit(main() or 0)

src/inference/io_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def get_slice_input_executorch(self, *args, **kwargs):
202202
data_gen = self._transformed_input[key]
203203
slice_data = [torch.from_numpy(copy.deepcopy(next(data_gen))) for _ in range(self._batch_size)]
204204
slice_input[key] = torch.stack(slice_data)
205-
return slice_input
205+
return slice_input
206206

207207
def get_result_filename(self, output_path, base_filename):
208208
base_suffix = '.' + base_filename.split('.')[-1]

0 commit comments

Comments (0)