Skip to content

Commit ede5181

Browse files
committed
[draft] Enhance python inference API
ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent c6eb5a0 commit ede5181

4 files changed

Lines changed: 131 additions & 26 deletions

File tree

runtime/onert/api/python/package/common/basesession.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,32 @@ def _recreate_session(self, backend_session):
5252
del self.session # Clean up the existing session
5353
self.session = backend_session
5454

55+
def get_inputs_tensorinfo(self) -> list:
56+
"""
57+
Retrieve tensorinfo for all input tensors.
58+
59+
Returns:
60+
list: A list of tensorinfo objects for each input.
61+
"""
62+
num_inputs = self.session.input_size()
63+
infos = []
64+
for i in range(num_inputs):
65+
infos.append(self.session.input_tensorinfo(i))
66+
return infos
67+
68+
def get_outputs_tensorinfo(self) -> list:
69+
"""
70+
Retrieve tensorinfo for all output tensors.
71+
72+
Returns:
73+
list: A list of tensorinfo objects for each output.
74+
"""
75+
num_outputs = self.session.output_size()
76+
infos = []
77+
for i in range(num_outputs):
78+
infos.append(self.session.output_tensorinfo(i))
79+
return infos
80+
5581
def set_inputs(self, size, inputs_array=[]):
5682
"""
5783
Set the input tensors for the session.

runtime/onert/api/python/package/infer/session.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,64 @@ class session(BaseSession):
66
"""
77
Class for inference using nnfw_session.
88
"""
9-
def __init__(self, path: str = None, backends: str = "cpu"):
9+
def __init__(self, path: str, backends: str = "cpu"):
1010
"""
1111
Initialize the inference session.
1212
Args:
1313
path (str): Path to the model file or nnpackage directory.
1414
backends (str): Backends to use, default is "cpu".
1515
"""
16-
if path is not None:
17-
super().__init__(libnnfw_api_pybind.infer.nnfw_session(path, backends))
18-
self.session.prepare()
19-
self.set_outputs(self.session.output_size())
20-
else:
21-
super().__init__()
16+
super().__init__(libnnfw_api_pybind.infer.nnfw_session(path, backends))
17+
self._prepared = False
2218

23-
def compile(self, path: str, backends: str = "cpu"):
19+
def update_inputs_tensorinfo(self, new_infos: list):
2420
"""
25-
Prepare the session by recreating it with new parameters.
21+
Update all input tensors' tensorinfo at once.
22+
2623
Args:
27-
path (str): Path to the model file or nnpackage directory. Defaults to the existing path.
28-
backends (str): Backends to use. Defaults to the existing backends.
24+
new_infos (list): A list of updated tensorinfo objects for the inputs.
25+
Raises:
26+
ValueError: If the number of new_infos does not match the session's input size.
2927
"""
30-
# Update parameters if provided
31-
if path is None:
32-
raise ValueError("path must not be None.")
33-
# Recreate the session with updated parameters
34-
self._recreate_session(libnnfw_api_pybind.infer.nnfw_session(path, backends))
35-
# Prepare the new session
36-
self.session.prepare()
37-
self.set_outputs(self.session.output_size())
38-
39-
def inference(self):
28+
num_inputs = self.session.input_size()
29+
if len(new_infos) != num_inputs:
30+
raise ValueError(
31+
f"Expected {num_inputs} input tensorinfo(s), but got {len(new_infos)}.")
32+
for i, info in enumerate(new_infos):
33+
self.session.set_input_tensorinfo(i, info)
34+
35+
def run_inference(self, inputs_array: list) -> list:
4036
"""
41-
Perform model and get outputs
37+
Run a complete inference cycle:
38+
- If the session hhas not been prepared or output buffers have not been set, call prepare() and set_outputs() once.
39+
- Automatically configure input buffers based on the provided numpy arrays.
40+
- Execute the inference session by calling nnfw_run().
41+
- Return the output tensors with proper multi-dimensional shapes.
42+
43+
This method supports both static and dynamic shape modification:
44+
- If update_inputs_tensorinfo() has been called before running inference, the model is compiled
45+
with the fixed static input shape.
46+
- Otherwise, the input shapes are adjusted dynamically during nnfw_run().
47+
48+
Args:
49+
inputs_array (list): List of numpy arrays representing the input data.
4250
Returns:
43-
list: Outputs from the model.
51+
list: A list containing te output numpy arrays.
4452
"""
53+
# Check if the session is prepared; if not, call prepare() and set_outputs() once.
54+
if not self._prepared:
55+
self.session.prepare()
56+
self.set_outputs(self.session.output_size())
57+
self._prepared = True
58+
59+
# Verify that the number of provided inputs matches the session's expected input count.
60+
if len(inputs_array) != self.session.input_size():
61+
raise ValueError(
62+
f"Expected {self.session.input_size()} input(s), but received {len(inputs_array)}."
63+
)
64+
# Configure input buffers using the current session's input size and provided data.
65+
self.set_inputs(self.session.input_size(), inputs_array)
66+
# Execute the inference.
4567
self.session.run()
68+
# Return the output buffers.
4669
return self.outputs

runtime/onert/sample/minimal-python/src/minimal.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from onert import infer
2+
import numpy as np
23
import sys
34

45

@@ -8,10 +9,17 @@ def main(nnpackage_path, backends="cpu"):
89
session = infer.session(nnpackage_path, backends)
910

1011
# Prepare input. Here we just allocate dummy input arrays.
11-
input_size = session.input_size()
12-
session.set_inputs(input_size)
12+
input_infos = session.get_inputs_tensorinfo()
13+
dummy_inputs = []
14+
for info in input_infos:
15+
# Retrieve the dimensions list from tensorinfo property.
16+
dims = list(info.dims)
17+
# Build the shape tuple from tensorinfo dimensions.
18+
shape = tuple(dims[:info.rank])
19+
# Create a dummy numpy array filled with zeros.
20+
dummy_inputs.append(np.zeros(shape, dtype=info.dtype))
1321

14-
outputs = session.inference()
22+
outputs = session.run_inference(dummy_inputs)
1523

1624
print(f"nnpackage {nnpackage_path.split('/')[-1]} runs successfully.")
1725
return
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from onert import infer
2+
import numpy as np
3+
import sys
4+
5+
def main(nnpackage_path, backends="cpu"):
6+
# Create session and load the nnpackage
7+
sess = infer.session(nnpackage_path, backends)
8+
9+
# Retrieve the current tensorinfo for all inputs.
10+
current_input_infos = sess.get_inputs_tensorinfo()
11+
12+
# Create new tensorinfo objects with a static shape modification.
13+
# For this example, assume we change the first dimension (e.g., batch size) to 10.
14+
new_input_infos = []
15+
for info in current_input_infos:
16+
# For example, if the current shape is (?, 4), update it to (10, 4).
17+
# We copy the current info and modify the rank and dims.
18+
# (Note: Depending on your model, you may want to modify additional dimensions.)
19+
new_shape = [10] + list(info.dims[1:info.rank])
20+
info.rank = len(new_shape)
21+
for i, dim in enumerate(new_shape):
22+
info.dims[i] = dim
23+
# For any remaining dimensions up to NNFW_MAX_RANK, set them to a default (1).
24+
for i in range(len(new_shape), len(info.dims)):
25+
info.dims[i] = 1
26+
new_input_infos.append(info)
27+
28+
# Update all input tensorinfos in the session at once.
29+
# This will call prepare() and set_outputs() internally.
30+
sess.update_inputs_tensorinfo(new_input_infos)
31+
32+
# Create dummy input arrays based on the new static shapes.
33+
dummy_inputs = []
34+
for info in new_input_infos:
35+
# Build the shape tuple from tensorinfo dimensions.
36+
shape = tuple(info.dims[:info.rank])
37+
# Create a dummy numpy array filled with zeros.
38+
dummy_inputs.append(np.zeros(shape, dtype=info.dtype))
39+
40+
# Run inference with the new static input shapes.
41+
outputs = sess.run_inference(dummy_inputs)
42+
43+
print(f"Static shape modification sample: nnpackage {nnpackage_path.split('/')[-1]} runs successfully.")
44+
return
45+
46+
if __name__ == "__main__":
47+
argv = sys.argv[1:]
48+
main(*argv)

0 commit comments

Comments
 (0)