Skip to content

Commit bd650a0

Browse files
committed
using ov opsets
1 parent 5370493 commit bd650a0

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

keras/src/backend/openvino/linalg.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import numpy as np
12
import openvino.opset15 as ov_opset
23
from openvino import Type
34

45
from keras.src.backend import config
56
from keras.src.backend import standardize_dtype
67
from keras.src.backend.common import dtypes
8+
from keras.src.backend.openvino.core import OPENVINO_DTYPES
79
from keras.src.backend.openvino.core import OpenVINOKerasTensor
810
from keras.src.backend.openvino.core import cast
11+
from keras.src.backend.openvino.core import convert_to_numpy
912
from keras.src.backend.openvino.core import convert_to_tensor
1013
from keras.src.backend.openvino.core import get_ov_output
1114

@@ -42,24 +45,38 @@ def eig(a):
4245

4346

4447
def eigh(a):
45-
import numpy as np
46-
from keras.src.backend.openvino.core import convert_to_numpy
4748
a = convert_to_tensor(a)
49+
a_ov = get_ov_output(a)
50+
51+
52+
a_ov_type = a_ov.get_element_type()
53+
if not a_ov_type.is_real():
54+
55+
a_ov = ov_opset.convert(a_ov, Type.f32).output(0)
56+
out_ov_type = Type.f32
57+
else:
58+
out_ov_type = a_ov_type
59+
60+
61+
a_evaluated = OpenVINOKerasTensor(a_ov)
4862
try:
49-
a_np = convert_to_numpy(a)
63+
a_np = convert_to_numpy(a_evaluated)
5064
except Exception as e:
5165
raise ValueError(
5266
"eigh is only supported for static eager tensors "
5367
"in the openvino backend. Received a dynamic or symbolic tensor."
5468
) from e
55-
56-
w, v = np.linalg.eigh(a_np)
69+
70+
71+
w_np, v_np = np.linalg.eigh(a_np)
72+
73+
w_const = ov_opset.constant(w_np, out_ov_type).output(0)
74+
v_const = ov_opset.constant(v_np, out_ov_type).output(0)
5775
return (
58-
ov_opset.constant(w).output(0),
59-
ov_opset.constant(v).output(0),
76+
OpenVINOKerasTensor(w_const),
77+
OpenVINOKerasTensor(v_const),
6078
)
6179

62-
6380
def inv(a):
6481
a = convert_to_tensor(a)
6582
a_ov = get_ov_output(a)

0 commit comments

Comments
 (0)