Skip to content

Commit 8c87f5d

Browse files
[OpenVINO Backend] Support np.diag (#20967)
* Support np.diag() for openvino * Update excluded tests file and fix code format * Fix formatting issue * Exclude diagonal and diagflat tests
1 parent ce9f968 commit 8c87f5d

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ NumpyDtypeTest::test_corrcoef
1919
NumpyDtypeTest::test_correlate
2020
NumpyDtypeTest::test_cross
2121
NumpyDtypeTest::test_cumprod
22-
NumpyDtypeTest::test_diag
22+
NumpyDtypeTest::test_diagflat
23+
NumpyDtypeTest::test_diagonal
2324
NumpyDtypeTest::test_einsum
2425
NumpyDtypeTest::test_exp2
2526
NumpyDtypeTest::test_flip
@@ -69,7 +70,7 @@ NumpyOneInputOpsCorrectnessTest::test_conj
6970
NumpyOneInputOpsCorrectnessTest::test_corrcoef
7071
NumpyOneInputOpsCorrectnessTest::test_correlate
7172
NumpyOneInputOpsCorrectnessTest::test_cumprod
72-
NumpyOneInputOpsCorrectnessTest::test_diag
73+
NumpyOneInputOpsCorrectnessTest::test_diagflat
7374
NumpyOneInputOpsCorrectnessTest::test_diagonal
7475
NumpyOneInputOpsCorrectnessTest::test_exp2
7576
NumpyOneInputOpsCorrectnessTest::test_flip

keras/src/backend/openvino/numpy.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,70 @@ def deg2rad(x):
760760

761761

762762
def diag(x, k=0):
763-
raise NotImplementedError("`diag` is not supported with openvino backend")
763+
x = get_ov_output(x)
764+
x_shape = x.get_partial_shape()
765+
rank = x_shape.rank.get_length()
766+
767+
if rank == 1:
768+
N_dim = x_shape[0]
769+
if not N_dim.is_static:
770+
raise ValueError(
771+
"diag requires input with static shape for 1D input."
772+
)
773+
N = N_dim.get_length()
774+
output_size = N + np.abs(k)
775+
out_shape = ov_opset.constant(
776+
[output_size, output_size], dtype=Type.i32
777+
).output(0)
778+
zeros_const = ov_opset.constant(0, x.get_element_type()).output(0)
779+
diag_matrix = ov_opset.broadcast(zeros_const, out_shape)
780+
781+
indices = []
782+
if k >= 0:
783+
for i in range(N):
784+
indices.append([i, i + k])
785+
else:
786+
for i in range(N):
787+
indices.append([i - k, i])
788+
789+
indices = np.array(indices, dtype=np.int32)
790+
indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)
791+
updated = ov_opset.scatter_nd_update(diag_matrix, indices_const, x)
792+
return OpenVINOKerasTensor(updated.output(0))
793+
794+
elif rank == 2:
795+
M_dim = x_shape[0]
796+
N_dim = x_shape[1]
797+
if not M_dim.is_static or not N_dim.is_static:
798+
raise ValueError(
799+
"diag requires input with static shape for 2D input."
800+
)
801+
M = M_dim.get_length()
802+
N = N_dim.get_length()
803+
804+
if k >= 0:
805+
L = np.minimum(M, N - k) if (N - k) > 0 else 0
806+
indices = [[i, i + k] for i in range(L)]
807+
else:
808+
L = np.minimum(M + k, N) if (M + k) > 0 else 0
809+
indices = [[i - k, i] for i in range(L)]
810+
811+
if L <= 0:
812+
keras_dtype = ov_to_keras_type(x.get_element_type())
813+
np_dtype = np.dtype(keras_dtype)
814+
empty_np = np.empty((0,), dtype=np_dtype)
815+
empty_const = ov_opset.constant(
816+
empty_np, x.get_element_type()
817+
).output(0)
818+
return OpenVINOKerasTensor(empty_const)
819+
820+
indices = np.array(indices, dtype=np.int32)
821+
indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)
822+
diag_vec = ov_opset.gather_nd(x, indices_const)
823+
return OpenVINOKerasTensor(diag_vec.output(0))
824+
825+
else:
826+
raise ValueError("diag supports only 1D or 2D tensors")
764827

765828

766829
def diagonal(x, offset=0, axis1=0, axis2=1):

0 commit comments

Comments
 (0)