@@ -6,41 +6,64 @@ class session(BaseSession):
66 """
77 Class for inference using nnfw_session.
88 """
9- def __init__ (self , path : str = None , backends : str = "cpu" ):
9+ def __init__ (self , path : str , backends : str = "cpu" ):
1010 """
1111 Initialize the inference session.
1212 Args:
1313 path (str): Path to the model file or nnpackage directory.
1414 backends (str): Backends to use, default is "cpu".
1515 """
16- if path is not None :
17- super ().__init__ (libnnfw_api_pybind .infer .nnfw_session (path , backends ))
18- self .session .prepare ()
19- self .set_outputs (self .session .output_size ())
20- else :
21- super ().__init__ ()
16+ super ().__init__ (libnnfw_api_pybind .infer .nnfw_session (path , backends ))
17+ self ._prepared = False
2218
23- def compile (self , path : str , backends : str = "cpu" ):
19+ def update_inputs_tensorinfo (self , new_infos : list ):
2420 """
25- Prepare the session by recreating it with new parameters.
21+ Update all input tensors' tensorinfo at once.
22+
2623 Args:
27- path (str): Path to the model file or nnpackage directory. Defaults to the existing path.
28- backends (str): Backends to use. Defaults to the existing backends.
24+ new_infos (list): A list of updated tensorinfo objects for the inputs.
25+ Raises:
26+ ValueError: If the number of new_infos does not match the session's input size.
2927 """
30- # Update parameters if provided
31- if path is None :
32- raise ValueError ("path must not be None." )
33- # Recreate the session with updated parameters
34- self ._recreate_session (libnnfw_api_pybind .infer .nnfw_session (path , backends ))
35- # Prepare the new session
36- self .session .prepare ()
37- self .set_outputs (self .session .output_size ())
38-
39- def inference (self ):
28+ num_inputs = self .session .input_size ()
29+ if len (new_infos ) != num_inputs :
30+ raise ValueError (
31+ f"Expected { num_inputs } input tensorinfo(s), but got { len (new_infos )} ." )
32+ for i , info in enumerate (new_infos ):
33+ self .session .set_input_tensorinfo (i , info )
34+
35+ def run_inference (self , inputs_array : list ) -> list :
4036 """
41- Perform model and get outputs
37+ Run a complete inference cycle:
38+ - If the session hhas not been prepared or output buffers have not been set, call prepare() and set_outputs() once.
39+ - Automatically configure input buffers based on the provided numpy arrays.
40+ - Execute the inference session by calling nnfw_run().
41+ - Return the output tensors with proper multi-dimensional shapes.
42+
43+ This method supports both static and dynamic shape modification:
44+ - If update_inputs_tensorinfo() has been called before running inference, the model is compiled
45+ with the fixed static input shape.
46+ - Otherwise, the input shapes are adjusted dynamically during nnfw_run().
47+
48+ Args:
49+ inputs_array (list): List of numpy arrays representing the input data.
4250 Returns:
43- list: Outputs from the model .
51+ list: A list containing te output numpy arrays .
4452 """
53+ # Check if the session is prepared; if not, call prepare() and set_outputs() once.
54+ if not self ._prepared :
55+ self .session .prepare ()
56+ self .set_outputs (self .session .output_size ())
57+ self ._prepared = True
58+
59+ # Verify that the number of provided inputs matches the session's expected input count.
60+ if len (inputs_array ) != self .session .input_size ():
61+ raise ValueError (
62+ f"Expected { self .session .input_size ()} input(s), but received { len (inputs_array )} ."
63+ )
64+ # Configure input buffers using the current session's input size and provided data.
65+ self .set_inputs (self .session .input_size (), inputs_array )
66+ # Execute the inference.
4567 self .session .run ()
68+ # Return the output buffers.
4669 return self .outputs
0 commit comments