Skip to content

Commit 71a6ac8

Browse files
committed
[interpreter] Fix Circle inference input binding
Use `ModelInputSpec.bind` in `tico/interpreter/infer.py` to produce `user_inputs` so `*args`/`**kwargs` are bound and ordered according to the Circle model input signature. TICO-DCO-1.0-Signed-off-by: seongwoo mhs4670go@naver.com
1 parent a0902ef commit 71a6ac8

3 files changed

Lines changed: 88 additions & 71 deletions

File tree

test/unit_test/utils/test_infer.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,24 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
from typing import ClassVar
17+
from unittest.mock import patch
1618

1719
import numpy as np
1820

1921
import tico
2022
import torch
23+
from tico.interpreter.interpreter import Interpreter
24+
from tico.interpreter import infer as infer_module
2125

2226
from test.modules.op.add import SimpleAdd
2327
from test.modules.op.avg_pool2d import AvgPoolWithPaddingKwargs
2428
from test.modules.op.cat import SimpleCatDefault, SimpleCatWithDim
2529

2630

31+
@unittest.skipUnless(
32+
Interpreter.is_available(), "one-compiler is required for circle inference"
33+
)
2734
class InferSimpleAddTest(unittest.TestCase):
2835
def setUp(self):
2936
# Input: torch.ones(1), torch.ones(1)
@@ -58,6 +65,9 @@ def test_add_float_builtin(self):
5865
)
5966

6067

68+
@unittest.skipUnless(
69+
Interpreter.is_available(), "one-compiler is required for circle inference"
70+
)
6171
class InferCatTest(unittest.TestCase):
6272
def test_concat(self):
6373
# convert
@@ -90,6 +100,9 @@ def test_concat_with_dim(self):
90100
)
91101

92102

103+
@unittest.skipUnless(
104+
Interpreter.is_available(), "one-compiler is required for circle inference"
105+
)
93106
class InferAvgPoolReverseKwargsTest(unittest.TestCase):
94107
def test_avgpool_reverse_kwargs(self):
95108
# convert
@@ -111,3 +124,58 @@ def test_avgpool_reverse_kwargs(self):
111124
np.testing.assert_allclose(
112125
actual=circle_model(**kwargs), desired=out_np, rtol=1e-4, atol=1e-4
113126
)
127+
128+
129+
class FakeInterpreter:
130+
instances: ClassVar[list["FakeInterpreter"]] = []
131+
132+
def __init__(self, circle_binary):
133+
self.circle_binary = circle_binary
134+
self.writes = []
135+
self.interpreted = False
136+
FakeInterpreter.instances.append(self)
137+
138+
def writeInputTensor(self, input_idx, input_data):
139+
self.writes.append((input_idx, input_data.clone()))
140+
141+
def interpret(self):
142+
self.interpreted = True
143+
144+
def readOutputTensor(self, output_idx, output):
145+
output.fill(0)
146+
147+
148+
class InferInputBindingTest(unittest.TestCase):
149+
def test_infer_binds_kwargs_in_model_input_order(self):
150+
m = AvgPoolWithPaddingKwargs()
151+
args, kwargs = m.get_example_inputs()
152+
circle_model = tico.convert(m.eval(), args, kwargs)
153+
154+
tensor0 = torch.randn(2, 4, 8, 16)
155+
tensor1 = torch.randn(2, 4, 4, 8)
156+
FakeInterpreter.instances = []
157+
158+
with patch.object(infer_module, "Interpreter", FakeInterpreter):
159+
infer_module.infer(
160+
circle_model.circle_binary, tensor0=tensor0, tensor1=tensor1
161+
)
162+
163+
fake = FakeInterpreter.instances[0]
164+
self.assertTrue(fake.interpreted)
165+
self.assertEqual(fake.writes[0][0], 0)
166+
self.assertEqual(fake.writes[0][1].shape, tensor1.shape)
167+
self.assertTrue(torch.equal(fake.writes[0][1], tensor1))
168+
self.assertEqual(fake.writes[1][0], 1)
169+
self.assertEqual(fake.writes[1][1].shape, tensor0.shape)
170+
self.assertTrue(torch.equal(fake.writes[1][1], tensor0))
171+
172+
173+
class InterpreterAvailabilityTest(unittest.TestCase):
174+
def test_missing_interpreter_library_does_not_raise_during_cleanup(self):
175+
with patch.object(
176+
Interpreter, "LIB_PATH", Interpreter.LIB_PATH.parent / "missing.so"
177+
):
178+
with self.assertRaisesRegex(
179+
RuntimeError, "Please install one-compiler for circle inference"
180+
):
181+
Interpreter(b"")

tico/interpreter/infer.py

Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,82 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Sequence
15+
from typing import Any
1616

1717
import numpy as np
18-
import torch
1918
from circle_schema import circle
2019

2120
from tico.interpreter.interpreter import Interpreter
22-
from tico.serialize.circle_mapping import np_dtype_from_circle_dtype, to_circle_dtype
23-
from tico.utils.installed_packages import is_dynamic_cache_available
24-
25-
26-
def flatten_and_convert(inputs: Sequence) -> tuple:
27-
result = [] # type: ignore[var-annotated]
28-
for item in inputs:
29-
if item is None:
30-
continue
31-
32-
# 1. recursion on list and tuple
33-
if isinstance(item, (list, tuple)):
34-
result.extend(flatten_and_convert(item))
35-
continue
36-
37-
# 2. handle DynamicCache
38-
if is_dynamic_cache_available():
39-
from transformers.cache_utils import DynamicCache
40-
41-
if isinstance(item, DynamicCache):
42-
# NOTE The tensor order is: key_in → key_out → value_in → value_out
43-
#
44-
# Refer to https://github.com/huggingface/transformers/blob/3457e8e73e4f5532cc69059682b1ba4484d7e7e8/src/transformers/cache_utils.py#L557
45-
# ```py
46-
# self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
47-
# self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
48-
# ```
49-
result.extend(item.key_cache)
50-
result.extend(item.value_cache)
51-
continue
52-
53-
# 3. Convert to tensors
54-
result.append(item if isinstance(item, torch.Tensor) else torch.tensor(item))
55-
56-
return tuple(result)
21+
from tico.serialize.circle_mapping import np_dtype_from_circle_dtype
22+
from tico.utils.signature import ModelInputSpec
5723

5824

5925
def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any:
60-
# When converting a model, it is assumed that the order of keyword arguments is maintained.
61-
raw_inputs = args + tuple(kwargs.values())
62-
user_inputs = flatten_and_convert(raw_inputs)
26+
input_spec = ModelInputSpec(circle_binary)
27+
user_inputs = input_spec.bind(args, kwargs, check=True)
6328

64-
# Get input spec from circle binary.
29+
# Get input/output spec from circle binary.
6530
model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
6631
assert model.SubgraphsLength() == 1
6732
graph = model.Subgraphs(0)
68-
model_input_tensors = [
69-
graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())
70-
]
71-
model_input_shapes_np = [t.ShapeAsNumpy() for t in model_input_tensors]
72-
model_input_types_cm = [t.Type() for t in model_input_tensors]
73-
74-
# Check if given inputs' dtype and shape from users match the inputs' from model binary.
75-
if len(model_input_shapes_np) != len(user_inputs):
76-
raise RuntimeError(
77-
f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})"
78-
)
79-
for input_idx, user_input in enumerate(user_inputs):
80-
# Shape
81-
if list(user_input.shape) != list(model_input_shapes_np[input_idx]):
82-
raise RuntimeError(
83-
f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})"
84-
)
85-
# Data type
86-
user_input_type_cm = to_circle_dtype(user_input.dtype)
87-
if user_input_type_cm != model_input_types_cm[input_idx]:
88-
raise RuntimeError(
89-
f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})"
90-
)
9133

9234
# Initialize interpreter
9335
intp = Interpreter(circle_binary)

tico/interpreter/interpreter.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from pathlib import Path
16+
from typing import ClassVar
1617

1718
import numpy as np
1819
import torch
@@ -33,10 +34,15 @@ class Interpreter:
3334
library do not cause undefined behavior in Python.
3435
"""
3536

37+
LIB_PATH: ClassVar[Path] = Path("/usr/share/one/lib/libcircle_interpreter_cffi.so")
38+
39+
@classmethod
40+
def is_available(cls) -> bool:
41+
return cls.LIB_PATH.is_file()
42+
3643
def __init__(self, circle_binary: bytes):
3744
self.ffi = FFI()
38-
self.ffi.cdef(
39-
"""
45+
self.ffi.cdef("""
4046
typedef struct InterpreterWrapper InterpreterWrapper;
4147
4248
const char *get_last_error(void);
@@ -46,21 +52,22 @@ def __init__(self, circle_binary: bytes):
4652
void Interpreter_interpret(InterpreterWrapper *intp);
4753
void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
4854
void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
49-
"""
50-
)
55+
""")
5156
# TODO Check if one-compiler version is compatible. Whether it has .so file or not for CFFI.
52-
intp_lib_path = Path("/usr/share/one/lib/libcircle_interpreter_cffi.so")
53-
if not intp_lib_path.is_file():
57+
if not self.is_available():
5458
raise RuntimeError("Please install one-compiler for circle inference.")
55-
self.C = self.ffi.dlopen(str(intp_lib_path))
59+
self.C = self.ffi.dlopen(str(self.LIB_PATH))
5660

5761
# Initialize interpreter
5862
self.intp = self.C.Interpreter_new(circle_binary, len(circle_binary))
5963
self.check_for_errors()
6064

6165
def delete(self):
66+
if not hasattr(self, "C") or not hasattr(self, "intp"):
67+
return
6268
self.C.Interpreter_delete(self.intp)
6369
self.check_for_errors()
70+
del self.intp
6471

6572
def interpret(self):
6673
self.C.Interpreter_interpret(self.intp)

0 commit comments

Comments
 (0)