Skip to content

Commit c2b16cf

Browse files
committed
codestyle fix
1 parent d31ce6b commit c2b16cf

File tree

1 file changed

+27
-33
lines changed

1 file changed

+27
-33
lines changed

src/inference/inference_executorch.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,18 @@
11
import argparse
2-
import importlib
32
import json
4-
import re
53
import sys
64
import traceback
75

8-
from functools import partial
96
from pathlib import Path
107
from time import time
118

129
import postprocessing_data as pp
13-
import preprocessing_data as prep
1410
from pathlib import Path
15-
import numpy as np
16-
from executorch.runtime import Verification, Runtime, Program, Method
11+
from executorch.runtime import Verification, Runtime
1712
from inference_tools.loop_tools import loop_inference, get_exec_time
1813
from io_adapter import IOAdapter
1914
from io_model_wrapper import ExecuTorchIOModelWrapper
2015
from reporter.report_writer import ReportWriter
21-
from configs.config_utils import prepend_to_path, to_camel_case, get_model_config
2216
from transformer import ExecuTorchTransformer
2317
from tvm_auxiliary import create_dict_for_transformer, create_dict_for_modelwrapper
2418

@@ -27,6 +21,7 @@
2721

2822
log = configure_logger()
2923

24+
3025
def cli_argument_parser():
3126
parser = argparse.ArgumentParser()
3227
parser.add_argument('-mn', '--model_name',
@@ -158,27 +153,27 @@ def cli_argument_parser():
158153

159154

160155
def load_model(model_path):
161-
et_runtime = Runtime.get()
162-
program = et_runtime.load_program(
163-
model_path,
164-
verification=Verification.Minimal,
165-
)
166-
return program.load_method("forward")
156+
et_runtime = Runtime.get()
157+
program = et_runtime.load_program(
158+
model_path,
159+
verification=Verification.Minimal,
160+
)
161+
return program.load_method("forward")
167162

168163

169164
def inference_executorch(net, num_iterations, get_slice, input_name, test_duration):
170-
predictions = None
171-
time_infer = []
172-
if num_iterations == 1:
173-
t0 = time()
174-
slice_input = get_slice()
175-
predictions = net.execute((slice_input[input_name],))
176-
t1 = time()
177-
time_infer.append(t1 - t0)
178-
else:
179-
loop_results = loop_inference(num_iterations, test_duration)(inference_iteration)(get_slice, input_name, net)
180-
time_infer = loop_results['time_infer']
181-
return predictions, time_infer
165+
predictions = None
166+
time_infer = []
167+
if num_iterations == 1:
168+
t0 = time()
169+
slice_input = get_slice()
170+
predictions = net.execute((slice_input[input_name],))
171+
t1 = time()
172+
time_infer.append(t1 - t0)
173+
else:
174+
loop_results = loop_inference(num_iterations, test_duration)(inference_iteration)(get_slice, input_name, net)
175+
time_infer = loop_results['time_infer']
176+
return predictions, time_infer
182177

183178

184179
def inference_iteration(get_slice, input_name, net):
@@ -194,13 +189,13 @@ def infer_slice(input_name, net, slice_input):
194189

195190

196191
def prepare_output(result, output_names, task_type):
197-
if task_type in ['feedforward']:
198-
return {}
199-
elif task_type in ['classification']:
200-
log.info('Converting output tensor to print results')
201-
return {output_names[0]: result[0].numpy()}
202-
else:
203-
raise ValueError(f'Unsupported task {task_type} to print inference results')
192+
if task_type in ['feedforward']:
193+
return {}
194+
elif task_type in ['classification']:
195+
log.info('Converting output tensor to print results')
196+
return {output_names[0]: result[0].numpy()}
197+
else:
198+
raise ValueError(f'Unsupported task {task_type} to print inference results')
204199

205200

206201
def main():
@@ -215,7 +210,6 @@ def main():
215210
io = IOAdapter.get_io_adapter(args, wrapper, transformer)
216211

217212
log.info(f'Shape for input layer {args.input_name}: {args.input_shape}')
218-
#report_writer.update_framework_info(name='ExecuTorch', version=executorch.__version__)
219213
net = load_model(args.model_path)
220214

221215
log.info(f'Preparing input data: {args.input}')

0 commit comments

Comments
 (0)