Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bindings/python/src/openvino/runtime/ie_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def compile_model(
device_name: Optional[str] = None,
config: Optional[dict] = None,
*,
weights: Optional[bytes] = None
weights: Optional[bytes] = None,
) -> CompiledModel:
"""Creates a compiled model.

Expand Down
17 changes: 13 additions & 4 deletions src/bindings/python/src/openvino/runtime/opset6/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ def mvn(
return _get_node_factory_opset6().create("MVN", inputs, attributes)


@overloading(Union[Node, Output], str, Optional[Union[type, np.dtype, Type, str]], Optional[Union[TensorShape, Shape, PartialShape]], Optional[str])
@overloading(Union[Node, Output, int, float, np.ndarray], str, Optional[Union[type, np.dtype, Type, str]],
Optional[Union[TensorShape, Shape, PartialShape]], Optional[str])
@nameable_op
def read_value(init_value: Union[Node, Output],
def read_value(init_value: Union[Node, Output, int, float, np.ndarray],
variable_id: str,
variable_type: Optional[Union[type, np.dtype, Type, str]] = None,
variable_shape: Optional[Union[TensorShape, Shape, PartialShape]] = None,
Expand All @@ -139,9 +140,13 @@ def read_value(init_value: Union[Node, Output],
info.data_type = get_element_type(variable_type)
else:
info.data_type = variable_type
else:
info.data_type = Type.dynamic

if variable_shape is not None:
info.data_shape = PartialShape(variable_shape)
else:
info.data_shape = PartialShape.dynamic()

var_from_info = Variable(info)
return _read_value(new_value=as_node(init_value, name=name), variable=var_from_info)
Expand Down Expand Up @@ -169,9 +174,13 @@ def read_value(variable_id: str, # noqa: F811
info.data_type = get_element_type(variable_type)
else:
info.data_type = variable_type
else:
info.data_type = Type.dynamic

if variable_shape is not None:
info.data_shape = PartialShape(variable_shape)
else:
info.data_shape = PartialShape.dynamic()

var_from_info = Variable(info)

Expand All @@ -191,9 +200,9 @@ def read_value(ov_variable: Variable, # noqa: F811
return _read_value(ov_variable)


@overloading(Union[Node, Output], Variable, Optional[str]) # type: ignore
@overloading(Union[Node, Output, int, float, np.ndarray], Variable, Optional[str]) # type: ignore
@nameable_op
def read_value(init_value: Union[Node, Output], # noqa: F811
def read_value(init_value: Union[Node, Output, int, float, np.ndarray], # noqa: F811
ov_variable: Variable,
name: Optional[str] = None) -> Node:
"""Return a node which produces the Assign operation.
Expand Down
2 changes: 1 addition & 1 deletion src/bindings/python/src/openvino/runtime/opset8/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def random_uniform(
"output_type": output_type,
"global_seed": global_seed,
"op_seed": op_seed,
"alignment": alignment.lower()
"alignment": alignment.lower(),
}
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)

Expand Down
38 changes: 29 additions & 9 deletions src/bindings/python/src/openvino/runtime/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import wraps
from inspect import getfullargspec
from inspect import signature
from typing import Any, Callable, Dict, Optional, Union, get_origin, get_args

from openvino.runtime import Node, Output
Expand Down Expand Up @@ -102,19 +102,39 @@ def check_invoked_types_in_overloaded_funcs(self, tuple_to_check: tuple, key_str
return False
return True

def __call__(self, *args) -> Any: # type: ignore
types = tuple(arg.__class__ for arg in args)
def __call__(self, *args, **kwargs) -> Any: # type: ignore
arg_types = tuple(arg.__class__ for arg in args)
kwarg_types = {key: type(value) for key, value in kwargs.items()}

key_matched = None
for key in self.typemap.keys():
if self.check_invoked_types_in_overloaded_funcs(types, key):
key_matched = key
break
if len(kwarg_types) == 0 and len(arg_types) != 0:
for key in self.typemap.keys():
# compare types of called function with overloads
if self.check_invoked_types_in_overloaded_funcs(arg_types, key):
key_matched = key
break
elif len(arg_types) == 0 and len(kwarg_types) != 0:
for key, func in self.typemap.items():
func_signature = {arg_name: types.annotation for arg_name, types in signature(func).parameters.items()}
# if kwargs of called function are subset of overloaded function, we use this overload
if kwarg_types.keys() <= func_signature.keys():
key_matched = key
break
elif len(arg_types) != 0 and len(kwarg_types) != 0:
for key, func in self.typemap.items():
func_signature = {arg_name: types.annotation for arg_name, types in signature(func).parameters.items()}
# compare types of called function with overloads
if self.check_invoked_types_in_overloaded_funcs(arg_types, tuple(func_signature.values())):
# if kwargs of called function are subset of overloaded function, we use this overload
if kwarg_types.keys() <= func_signature.keys():
key_matched = key
break

if key_matched is None:
raise TypeError("no match")
raise TypeError(f"The necessary overload for {self.name} was not found")

function = self.typemap.get(key_matched)
return function(*args) # type: ignore
return function(*args, **kwargs) # type: ignore

def register(self, types: tuple, function: Callable) -> None:
if types in self.typemap:
Expand Down
27 changes: 27 additions & 0 deletions src/bindings/python/tests/test_graph/test_create_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import openvino.runtime.opset5 as ov_opset5
import openvino.runtime.opset10 as ov_opset10
import openvino.runtime.opset11 as ov
from openvino.runtime.op.util import VariableInfo, Variable

np_types = [np.float32, np.int32]
integral_np_types = [
Expand Down Expand Up @@ -1218,6 +1219,32 @@ def test_read_value():
assert read_value_attributes["variable_shape"] == [2, 2]


def test_read_value_ctors():
data = np.ones((1, 64), dtype=np.float32)
# check mixed args&kwargs creation
read_value = ov.read_value(data, "variable_id_1", name="read_value")
assert read_value.friendly_name == "read_value"

var_info = VariableInfo()
var_info.data_shape = PartialShape([1, 64])
var_info.data_type = Type.f32
var_info.variable_id = "v1"
variable_1 = Variable(var_info)

# check kwargs creation
read_value_1 = ov.read_value(init_value=data, ov_variable=variable_1)
assert list(read_value_1.get_output_shape(0)) == [1, 64]

# check args creation
read_value_2 = ov.read_value(variable_1)
assert list(read_value_2.get_output_shape(0)) == [1, 64]

with pytest.raises(TypeError) as e:
ov.read_value(data, "variable_id_1", 2)

assert "The necessary overload for read_value was not found" in str(e.value)


def test_read_value_dyn_variable_pshape():
init_value = ov.parameter([2, 2], name="init_value", dtype=np.int32)

Expand Down