Skip to content

Commit 528a9ed

Browse files
committed
[Python API] Add tuple support for opset.constant
Fixes #34210
1 parent 9911cd8 commit 528a9ed

File tree

4 files changed

+42
-4
lines changed

4 files changed

+42
-4
lines changed

src/bindings/python/src/openvino/opset1/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def concat(nodes: list[NodeInput], axis: int, name: Optional[str] = None) -> Nod
333333

334334
@nameable_op
335335
def constant(
336-
value: Union[NumericData, np.number, bool, np.bool_, list],
336+
value: Union[NumericData, np.number, bool, np.bool_, list, tuple],
337337
dtype: Union[NumericType, Type] = None,
338338
name: Optional[str] = None,
339339
) -> Constant:

src/bindings/python/src/openvino/opset13/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def scaled_dot_product_attention(
290290
) # type: ignore
291291
@nameable_op
292292
def constant(
293-
value: Union[NumericData, np.number, bool, np.bool_, list],
293+
value: Union[NumericData, np.number, bool, np.bool_, list, tuple],
294294
dtype: Union[NumericType, Type] = None,
295295
name: Optional[str] = None,
296296
*,

src/bindings/python/src/openvino/utils/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
log = logging.getLogger(__name__)
1717

1818
TensorShape = list[int]
19-
NumericData = Union[int, float, np.ndarray]
19+
NumericData = Union[int, float, np.ndarray, tuple]
2020
NumericType = Union[type, np.dtype]
2121
ScalarData = Union[int, float]
2222
NodeInput = Union[Node, NumericData]
@@ -140,7 +140,7 @@ def get_shape(data: NumericData) -> TensorShape:
140140
"""Return a shape of NumericData."""
141141
if isinstance(data, np.ndarray):
142142
return data.shape # type: ignore
143-
if isinstance(data, list):
143+
if isinstance(data, (list, tuple)):
144144
return [len(data)] # type: ignore
145145
return []
146146

src/bindings/python/tests/test_graph/test_constant.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,3 +796,41 @@ def test_const_from_tensor_with_shared_memory_by_default():
796796
arr += 1
797797
assert np.array_equal(ov_const.data, arr)
798798
assert np.shares_memory(arr, ov_const.data)
799+
800+
801+
def test_constant_from_tuple():
802+
value = (1, 2, 3)
803+
ov_const = ops.constant(value, dtype=np.int32)
804+
805+
assert isinstance(ov_const, Constant)
806+
assert list(ov_const.shape) == [3]
807+
assert ov_const.get_element_type() == Type.i32
808+
assert np.array_equal(ov_const.data, np.array(value, dtype=np.int32))
809+
810+
811+
def test_constant_from_tuple_float():
812+
value = (1.0, 2.5, 3.5)
813+
ov_const = ops.constant(value, dtype=np.float32)
814+
815+
assert isinstance(ov_const, Constant)
816+
assert list(ov_const.shape) == [3]
817+
assert ov_const.get_element_type() == Type.f32
818+
assert np.allclose(ov_const.data, np.array(value, dtype=np.float32))
819+
820+
821+
def test_constant_from_nested_tuple():
822+
value = ((1, 2), (3, 4))
823+
ov_const = ops.constant(value, dtype=np.int32)
824+
825+
assert isinstance(ov_const, Constant)
826+
assert list(ov_const.shape) == [2, 2]
827+
assert np.array_equal(ov_const.data, np.array(value, dtype=np.int32))
828+
829+
830+
def test_constant_from_tuple_no_dtype():
831+
value = (1, 2, 3)
832+
ov_const = ops.constant(value)
833+
834+
assert isinstance(ov_const, Constant)
835+
assert list(ov_const.shape) == [3]
836+
assert np.array_equal(ov_const.data, np.array(value))

0 commit comments

Comments
 (0)