diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 61db97b50c37..c9319e7e8179 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -346,7 +346,6 @@ LayerTest::test_end_to_end_masking LayerTest::test_quantized_layer_with_remat LayerTest::test_stateless_call LinalgOpsCorrectnessTest::test_cholesky -LinalgOpsCorrectnessTest::test_det LinalgOpsCorrectnessTest::test_eig LinalgOpsCorrectnessTest::test_eigh LinalgOpsCorrectnessTest::test_lstsq diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index 1ecce8303959..541e348a2afb 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -1,3 +1,4 @@ +import openvino as ov import openvino.opset15 as ov_opset from openvino import Type @@ -34,7 +35,120 @@ def cholesky_inverse(a, upper=False): def det(a): - raise NotImplementedError("`det` is not supported with openvino backend") + a = convert_to_tensor(a) + a_ov = get_ov_output(a) + original_type = a_ov.get_element_type() + + # Avoid constant folding bug for f64 in OpenVINO CPU Loop evaluate + if original_type == Type.f64: + a_ov = ov_opset.convert(a_ov, Type.f32).output(0) + + a_shape = ov_opset.shape_of(a_ov, output_type="i32").output(0) + + rank = a_ov.get_partial_shape().rank.get_length() + + minus_1 = ov_opset.constant([-1], Type.i32).output(0) + minus_2 = ov_opset.constant([-2], Type.i32).output(0) + + N_node_1d = ov_opset.gather( + a_shape, minus_1, ov_opset.constant(0, Type.i32).output(0) + ).output(0) + N_node_scalar = ov_opset.squeeze( + N_node_1d, ov_opset.constant([0], Type.i32).output(0) + ).output(0) + + num_batch_dims = rank - 2 + if num_batch_dims > 0: + batch_dims_shape = ov_opset.broadcast( + ov_opset.constant([1], Type.i32).output(0), + ov_opset.constant([num_batch_dims], Type.i32).output(0), + ).output(0) + eye_shape = ov_opset.concat( + [batch_dims_shape, N_node_1d, N_node_1d], 0 + ).output(0) + else: + eye_shape = ov_opset.concat([N_node_1d, N_node_1d], 0).output(0) + + eye = ov_opset.eye( + N_node_scalar, + N_node_scalar, + ov_opset.constant(0, Type.i32).output(0), + a_ov.get_element_type(), + ).output(0) + eye_reshaped = ov_opset.reshape(eye, eye_shape, False).output(0) + + trip_count = N_node_scalar + loop = ov_opset.loop( + trip_count, ov_opset.constant(True, Type.boolean).output(0) + ) + + M_param = ov_opset.parameter([-1] * rank, a_ov.get_element_type(), "M") + k_param = ov_opset.parameter([], Type.i32, "k") + A_body_param = ov_opset.parameter( + [-1] * rank, a_ov.get_element_type(), "A_body" + ) + eye_body_param = ov_opset.parameter( + [-1] * rank, a_ov.get_element_type(), "eye_body" + ) + + k_next = ov_opset.add( + k_param.output(0), ov_opset.constant(1, Type.i32).output(0) + ).output(0) + k_f32 = ov_opset.convert(k_next, a_ov.get_element_type()).output(0) + + M_diag = ov_opset.multiply( + M_param.output(0), eye_body_param.output(0) + ).output(0) + trace_axes = ov_opset.concat([minus_2, minus_1], 0).output(0) + trace = ov_opset.reduce_sum(M_diag, trace_axes, keep_dims=True).output(0) + + minus_one = ov_opset.constant(-1.0, a_ov.get_element_type()).output(0) + c_k_factor = ov_opset.divide(minus_one, k_f32).output(0) + c_k = ov_opset.multiply(c_k_factor, trace).output(0) + + c_k_I = ov_opset.multiply(c_k, eye_body_param.output(0)).output(0) + M_plus_c_k_I = ov_opset.add(M_param.output(0), c_k_I).output(0) + + M_next = ov_opset.matmul( + A_body_param.output(0), M_plus_c_k_I, False, False + ).output(0) + + cond_next = ov_opset.constant(True, Type.boolean).output(0) + + body = ov.Model( + [M_next, k_next, c_k, cond_next], + [M_param, k_param, A_body_param, eye_body_param], + ) + loop.set_function(body) + loop.set_special_body_ports([-1, 3]) + + loop.set_merged_input(M_param, a_ov, M_next) + loop.set_merged_input( + k_param, ov_opset.constant(0, Type.i32).output(0), k_next + ) + loop.set_invariant_input(A_body_param, a_ov) + loop.set_invariant_input(eye_body_param, eye_reshaped) + + out_c_k = loop.get_iter_value(c_k, -1) + + det_c_k = ov_opset.squeeze(out_c_k, trace_axes).output(0) + + N_mod_2 = ov_opset.mod( + N_node_scalar, ov_opset.constant(2, Type.i32).output(0) + ).output(0) + N_mod_2_f32 = ov_opset.convert(N_mod_2, a_ov.get_element_type()).output(0) + one = ov_opset.constant(1.0, a_ov.get_element_type()).output(0) + two = ov_opset.constant(2.0, a_ov.get_element_type()).output(0) + sign = ov_opset.subtract( + one, ov_opset.multiply(two, N_mod_2_f32).output(0) + ).output(0) + + det = ov_opset.multiply(det_c_k, sign).output(0) + + if original_type == Type.f64: + det = ov_opset.convert(det, Type.f64).output(0) + + return OpenVINOKerasTensor(det) def eig(a):